mirror of https://github.com/llvm/torch-mlir
[RefBackend] Use std.global_memref instead of homegrown thing
This vastly simplifies our code, allowing deleting multiple ops, simplifying multiple passes, and removing a whole pass. Now `refback` dialect is down to one op (refback.alloc_memref, which simplifies allocations to just take a shape instead of individual extents).pull/119/head
parent
6850295ec5
commit
5227d52c26
|
@ -20,21 +20,6 @@ class Refback_Op<string mnemonic, list<OpTrait> traits = []>
|
|||
: Op<Refback_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
def Refback_GlobalOp : Refback_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to bufferization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -52,15 +37,4 @@ def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> {
|
|||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def Refback_GetGlobalMemrefOp : Refback_Op<"get_global_memref"> {
|
||||
let summary = "Obtain a memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a memref pointing at the given global.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
}
|
||||
|
||||
#endif // REFBACK_OPS
|
||||
|
|
|
@ -46,43 +46,6 @@ def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
|
|||
let assemblyFormat = "$pred `,` $msg attr-dict";
|
||||
}
|
||||
|
||||
def Refbackrt_GlobalOp : Refbackrt_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Refbackrt_GetGlobalOp : Refbackrt_Op<"get_global"> {
|
||||
let summary = "Obtain a rank-erased memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a rank-erased memref pointing at the given global.
|
||||
|
||||
TODO: As we define the runtime layer better, we should have fewer
|
||||
entry points that return memrefs, or at least have a clearer separation
|
||||
between the "memref world" and the "refbackrt world".
|
||||
Something like forming IREE dispatch regions seems to be the missing thing:
|
||||
- Everything inside the dispatch regions gets things marshaled from the
|
||||
runtime (flow/hal/refbackrt) layer to/from memrefs in a clear way.
|
||||
- Everything outside the dispatch regions purely uses the runtime
|
||||
(flow/hal/refbackrt) data structures.
|
||||
Globals should be one of the things that are purely runtime data structures,
|
||||
rather than using memrefs. For now, using memrefs is simpler though.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyUnrankedMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
}
|
||||
|
||||
def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
|
||||
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
||||
]> {
|
||||
|
|
|
@ -11,15 +11,6 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def LowerConstantTensorsToMemref :
|
||||
Pass<"lower-constant-tensors-to-memref", "ModuleOp"> {
|
||||
let summary = "Lower std.constant of tensor type to memref";
|
||||
let description = [{
|
||||
This must be a module pass since it involves creating refback.global ops.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()";
|
||||
}
|
||||
|
||||
def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
|
||||
let summary = "Lower constructs requiring runtime support to `refbackrt`";
|
||||
let description = [{
|
||||
|
@ -30,7 +21,6 @@ def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
|
|||
|
||||
The constructs requiring runtime support are:
|
||||
- function signatures / module metadata
|
||||
- globals
|
||||
- error handling
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";
|
||||
|
|
|
@ -23,9 +23,6 @@ void registerRefBackendPasses();
|
|||
//
|
||||
// Pass summaries are in Passes.td.
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createLowerConstantTensorsToMemrefPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();
|
||||
|
|
|
@ -15,49 +15,5 @@ using namespace mlir;
|
|||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::refback;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "refback.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid symbol";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (failed(
|
||||
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
|
||||
return op.emitError() << "inconsistent with shape of global";
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"
|
||||
|
|
|
@ -16,47 +16,6 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::refbackrt;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "refbackrt.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid refbackrt.global";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleMetadataOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -5,7 +5,6 @@ add_npcomp_library(NPCOMPRefBackend
|
|||
RefBackend.cpp
|
||||
LowerToLLVM.cpp
|
||||
LowerToRefbackrtABI.cpp
|
||||
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SRC_DIR}/include/npcomp/RefBackend
|
||||
|
|
|
@ -56,22 +56,6 @@ static LLVMType getModuleDescriptorTy(MLIRContext *context) {
|
|||
});
|
||||
}
|
||||
|
||||
// Get the LLVMType for refbackrt::GlobalDescriptor.
|
||||
static LLVMType getGlobalDescriptorTy(MLIRContext *context) {
|
||||
return LLVMType::getStructTy(
|
||||
// std::int32_t numExtents;
|
||||
LLVMType::getIntNTy(context, 32),
|
||||
// std::int32_t *extents;
|
||||
LLVMType::getIntNTy(context, 32).getPointerTo(),
|
||||
// It is important that this struct member is a type-erased pointer
|
||||
// so that this type is "context-free" and can be created in conversion
|
||||
// patterns independently of the actual type of the data stored in the
|
||||
// buffer.
|
||||
//
|
||||
// void *data;
|
||||
LLVMType::getInt8PtrTy(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compiler runtime functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -122,35 +106,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class GetGlobalOpCompilerRuntimeLowering
|
||||
: public OpConversionPattern<refbackrt::GetGlobalOp> {
|
||||
public:
|
||||
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||
: OpConversionPattern<refbackrt::GetGlobalOp>(backingFunc.getContext()),
|
||||
backingFunc(backingFunc) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(refbackrt::GetGlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// It would be nice if we could use the constructor here that takes just the
|
||||
// global, but keeping track of the converted llvm.mlir.global op that gets
|
||||
// created from the refbackrt.global while conversion is going on is a
|
||||
// headache.
|
||||
//
|
||||
// Instead, we rely on the symbol name being the same and the result type
|
||||
// always being the same.
|
||||
auto globalAddr = rewriter.create<LLVM::AddressOfOp>(
|
||||
op.getLoc(),
|
||||
getGlobalDescriptorTy(rewriter.getContext()).getPointerTo(),
|
||||
op.globalAttr());
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc,
|
||||
ValueRange({globalAddr}));
|
||||
return success();
|
||||
}
|
||||
LLVM::LLVMFuncOp backingFunc;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
|
||||
OpBuilder &builder, Location loc) {
|
||||
// TODO: Deduplicate strings.
|
||||
|
@ -262,163 +217,7 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
|||
"from_memref", funcTy, builder, module.getLoc());
|
||||
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
|
||||
}
|
||||
|
||||
{
|
||||
// Hardcoding f32 is fine here, since unranked memref descriptors have
|
||||
// identical struct layout / ABI / contents regardless of the element type.
|
||||
auto mlirFunctionType = builder.getFunctionType(
|
||||
{getGlobalDescriptorTy(context).getPointerTo()},
|
||||
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
|
||||
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
||||
LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl(
|
||||
"get_global", funcTy, builder, module.getLoc());
|
||||
patterns.insert<GetGlobalOpCompilerRuntimeLowering>(backingFunc);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering for refbackrt.global
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerRefbackrtGlobalOp : public OpConversionPattern<refbackrt::GlobalOp> {
|
||||
public:
|
||||
explicit LowerRefbackrtGlobalOp(LLVMTypeConverter &typeConverter)
|
||||
: OpConversionPattern<refbackrt::GlobalOp>(&typeConverter.getContext()),
|
||||
typeConverter(typeConverter) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(refbackrt::GlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = rewriter.getContext();
|
||||
auto globalDescriptorTy = getGlobalDescriptorTy(context);
|
||||
|
||||
// Create the data buffer.
|
||||
auto dataBuffer = createGlobalForDenseElementsAttr(
|
||||
(Twine("__refbackrt_global_data_buffer_") + op.sym_name()).str(),
|
||||
op.value().cast<DenseElementsAttr>(), op, rewriter);
|
||||
|
||||
// Create the extents buffer.
|
||||
auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>(
|
||||
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
|
||||
[](int64_t i) -> int32_t { return i; })));
|
||||
auto extentsBuffer = createGlobalForDenseElementsAttr(
|
||||
(Twine("__refbackrt_global_extents_") + op.sym_name()).str(),
|
||||
extentsI32, op, rewriter);
|
||||
|
||||
// Create the GlobalDescriptor.
|
||||
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
|
||||
op.getLoc(), globalDescriptorTy, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute());
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.createBlock(&globalDescriptorGlobal.initializer());
|
||||
|
||||
// Create the body of the initializer.
|
||||
Value globalDescriptor =
|
||||
rewriter.create<LLVM::UndefOp>(op.getLoc(), globalDescriptorTy);
|
||||
auto updateDescriptor = [&](Value value,
|
||||
std::initializer_list<int32_t> position) {
|
||||
globalDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op.getLoc(), globalDescriptor, value,
|
||||
/*position=*/rewriter.getI32ArrayAttr(position));
|
||||
};
|
||||
updateDescriptor(
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), LLVMType::getIntNTy(context, 32),
|
||||
rewriter.getI32IntegerAttr(
|
||||
op.value().getType().cast<ShapedType>().getRank())),
|
||||
{0});
|
||||
|
||||
// The global is actually an array, so we need to get a bare i32* pointer
|
||||
// type. We could do this with GEP but it would be more verbose.
|
||||
auto extentsBufferArrayAddress =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), extentsBuffer);
|
||||
auto extentsBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||
op.getLoc(), LLVMType::getIntNTy(context, 32).getPointerTo(),
|
||||
extentsBufferArrayAddress);
|
||||
updateDescriptor(extentsBufferAddress, {1});
|
||||
|
||||
auto dataBufferAddress =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), dataBuffer);
|
||||
auto typeErasedDataBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||
op.getLoc(), LLVMType::getInt8PtrTy(context), dataBufferAddress);
|
||||
updateDescriptor(typeErasedDataBufferAddress, {2});
|
||||
rewriter.create<LLVM::ReturnOp>(op.getLoc(), globalDescriptor);
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO: It feels like MLIR core should have better utilities for this.
|
||||
LLVM::GlobalOp createGlobalForDenseElementsAttr(
|
||||
StringRef symbolName, DenseElementsAttr elements, refbackrt::GlobalOp op,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto type = elements.getType().cast<ShapedType>();
|
||||
|
||||
// LLVM translation doesn't handle the case of zero-sized tensors, which can
|
||||
// happen e.g. for the number of extents of a rank-0 (i.e. scalar).
|
||||
//
|
||||
// We fake-up a size-1 DenseElementsAttr to use for creating the global.
|
||||
// That takes up binary space (one element instead of zero), but that seems
|
||||
// fine.
|
||||
//
|
||||
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||
if (type.getNumElements() == 0) {
|
||||
auto elementType = type.getElementType();
|
||||
Attribute singleElement;
|
||||
if (elementType.isIntOrIndex())
|
||||
singleElement = rewriter.getIntegerAttr(elementType, 0);
|
||||
else if (elementType.isa<FloatType>())
|
||||
singleElement = rewriter.getFloatAttr(elementType, 0);
|
||||
assert(singleElement &&
|
||||
"could not fake up an element for a zero element tensor");
|
||||
type = RankedTensorType::get({1}, elementType);
|
||||
elements =
|
||||
DenseElementsAttr::get(type, ArrayRef<Attribute>(singleElement));
|
||||
}
|
||||
|
||||
auto llvmType = getLLVMTypeForShapedType(type, op, rewriter);
|
||||
return rewriter.create<LLVM::GlobalOp>(
|
||||
op.getLoc(), llvmType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
|
||||
}
|
||||
|
||||
LLVMType getLLVMTypeForShapedType(ShapedType type, refbackrt::GlobalOp op,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto llvmType =
|
||||
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
|
||||
|
||||
// MLIR->LLVM lowering for globals requires non-scalar data types. So use a
|
||||
// dummy size-1 array for the scalar case.
|
||||
//
|
||||
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||
if (type.getRank() == 0)
|
||||
return LLVMType::getArrayTy(llvmType, 1);
|
||||
|
||||
if (!llvmType) {
|
||||
rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "cannot convert element type " << type.getElementType()
|
||||
<< " to an LLVM type";
|
||||
});
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Construct an LLVM nested array type for the tensor initializer.
|
||||
// tensor<f32> -> float
|
||||
// tensor<10xf32> -> [10 x float]
|
||||
// tensor<2x3xf32> -> [2 x [3 x float]]
|
||||
assert(type.hasStaticShape());
|
||||
auto shape = type.getShape();
|
||||
while (!shape.empty()) {
|
||||
llvmType = LLVMType::getArrayTy(llvmType, shape.back());
|
||||
shape = shape.drop_back();
|
||||
}
|
||||
return llvmType;
|
||||
}
|
||||
LLVMTypeConverter &typeConverter;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering for module metadata
|
||||
|
@ -706,7 +505,6 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
|||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
patterns.insert<LowerModuleMetadata>(context);
|
||||
patterns.insert<LowerRefbackrtGlobalOp>(converter);
|
||||
|
||||
// TODO: Move these "std to std" legalizations to their own pass if we grow
|
||||
// lots of these patterns.
|
||||
|
|
|
@ -74,37 +74,6 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
|||
// Dialect conversion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerGlobalOp : public OpConversionPattern<refback::GlobalOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<refbackrt::GlobalOp>(op, op.sym_name(),
|
||||
op.value());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerGetGlobalMemrefOp
|
||||
: public OpConversionPattern<refback::GetGlobalMemrefOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto abiMemref = rewriter.create<refbackrt::GetGlobalOp>(
|
||||
op.getLoc(), getABIMemrefType(op.getType()), op.global());
|
||||
// Cast back to the original type.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerAssertOp : public OpConversionPattern<AssertOp> {
|
||||
public:
|
||||
|
@ -216,12 +185,6 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
|||
target.addDynamicallyLegalOp<ReturnOp>(
|
||||
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
patterns.insert<LowerGlobalOp>(context);
|
||||
target.addIllegalOp<refback::GlobalOp>();
|
||||
|
||||
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
||||
target.addIllegalOp<refback::GetGlobalMemrefOp>();
|
||||
|
||||
patterns.insert<LowerAssertOp>(context);
|
||||
target.addIllegalOp<AssertOp>();
|
||||
|
||||
|
|
|
@ -228,8 +228,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
|
||||
// Bufferize the TCP dialect.
|
||||
pm.addNestedPass<FuncOp>(createTCPBufferizePass());
|
||||
// Lower tensor-valued constants to refback.global.
|
||||
pm.addPass(createLowerConstantTensorsToMemrefPass());
|
||||
pm.addPass(createTensorConstantBufferizePass());
|
||||
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
|
||||
// We need to resolve this to std.alloc which takes individual extents.
|
||||
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());
|
||||
|
|
|
@ -48,13 +48,6 @@ struct ModuleDescriptor {
|
|||
FuncDescriptor *functionDescriptors;
|
||||
};
|
||||
|
||||
// Static data representing a global variable (together with its shape).
|
||||
struct GlobalDescriptor {
|
||||
std::int32_t numExtents;
|
||||
std::int32_t *extents;
|
||||
void *data;
|
||||
};
|
||||
|
||||
} // namespace refbackrt
|
||||
|
||||
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||
|
|
|
@ -114,11 +114,3 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank,
|
|||
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
|
||||
elementType, data);
|
||||
}
|
||||
|
||||
extern "C" UnrankedMemref
|
||||
__npcomp_compiler_rt_get_global(GlobalDescriptor *global) {
|
||||
auto *descriptor = MemrefDescriptor::create(
|
||||
ArrayRef<std::int32_t>(global->extents, global->numExtents),
|
||||
global->data);
|
||||
return UnrankedMemref{global->numExtents, descriptor};
|
||||
}
|
||||
|
|
|
@ -1,113 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LowerConstantTensorsToMemref
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// This class creates global ops for all tensor-valued constants in the program.
|
||||
// It creates them with pretty names and makes sure that duplicate globals
|
||||
// aren't created.
|
||||
class GlobalCreator {
|
||||
public:
|
||||
explicit GlobalCreator(ModuleOp module);
|
||||
refback::GlobalOp getGlobalFor(Attribute attr) {
|
||||
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
||||
return globals[attr];
|
||||
}
|
||||
|
||||
private:
|
||||
DenseMap<Attribute, refback::GlobalOp> globals;
|
||||
};
|
||||
|
||||
GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||
// Create a builder without an insertion point. We will insert using the
|
||||
// symbol table to guarantee unique names.
|
||||
OpBuilder globalBuilder(module.getContext());
|
||||
SymbolTable symbolTable(module);
|
||||
module.walk([&](ConstantOp op) {
|
||||
// We only want tensor constants for now.
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return;
|
||||
// If we already have a global for this constant value, no need to do
|
||||
// anything else.
|
||||
auto it = globals.find(op.getValue());
|
||||
if (it != globals.end())
|
||||
return;
|
||||
|
||||
// Create a pretty name.
|
||||
SmallString<64> buf;
|
||||
llvm::raw_svector_ostream os(buf);
|
||||
interleave(type.getShape(), os, "x");
|
||||
os << "x" << type.getElementType();
|
||||
|
||||
auto global = globalBuilder.create<refback::GlobalOp>(
|
||||
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||
op.getValue().cast<ElementsAttr>());
|
||||
symbolTable.insert(global);
|
||||
// The symbol table inserts at the end of the module, but globals are a bit
|
||||
// nicer if they are at the beginning.
|
||||
global.getOperation()->moveBefore(&module.front());
|
||||
globals[op.getValue()] = global;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerConstantTensorsToMemref
|
||||
: public LowerConstantTensorsToMemrefBase<LowerConstantTensorsToMemref> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<refback::RefbackDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
GlobalCreator globals(module);
|
||||
|
||||
// With the global traversal factored into GlobalCreator, this could in
|
||||
// principle be done with a pattern.
|
||||
module.walk([&](ConstantOp op) {
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return;
|
||||
auto global = globals.getGlobalFor(op.getValue());
|
||||
OpBuilder builder(op);
|
||||
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
||||
auto memref = builder.create<refback::GetGlobalMemrefOp>(
|
||||
op.getLoc(), memrefType, global.getName());
|
||||
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
|
||||
op.replaceAllUsesWith(tensor);
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::createLowerConstantTensorsToMemrefPass() {
|
||||
return std::make_unique<LowerConstantTensorsToMemref>();
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s
|
||||
|
||||
// -----
|
||||
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{must reference a valid symbol}}
|
||||
refback.get_global_memref @nonexistent_symbol : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with shape of global}}
|
||||
refback.get_global_memref @g : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
refback.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with element type of global}}
|
||||
refback.get_global_memref @g : memref<2xi8>
|
||||
return
|
||||
}
|
|
@ -1,11 +1,8 @@
|
|||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: refback.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||
refback.global @foo dense<0.0> : tensor<10xf32>
|
||||
|
||||
// CHECK-LABEL: func @global
|
||||
func @global() {
|
||||
// CHECK: refback.get_global_memref @foo : memref<10xf32>
|
||||
%0 = refback.get_global_memref @foo : memref<10xf32>
|
||||
// CHECK-LABEL: @alloc_memref
|
||||
func @alloc_memref(%arg0: tensor<?xindex>) {
|
||||
// CHECK: refback.alloc_memref
|
||||
%0 = refback.alloc_memref %arg0 : memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -21,23 +21,3 @@ refbackrt.module_metadata {
|
|||
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
||||
}
|
||||
func @f() { return }
|
||||
|
||||
// -----
|
||||
|
||||
refbackrt.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{must reference a valid refbackrt.global}}
|
||||
refbackrt.get_global @nonexistent_symbol : memref<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
refbackrt.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with element type of global}}
|
||||
refbackrt.get_global @g : memref<*xi8>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -11,10 +11,3 @@ refbackrt.module_metadata {
|
|||
func @f(%arg0: !refbackrt.tensor) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: refbackrt.global @g dense<0.0{{.*}}> : tensor<10xf32>
|
||||
refbackrt.global @g dense<0.0> : tensor<10xf32>
|
||||
func @uses_global() {
|
||||
refbackrt.get_global @g : memref<*xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memref <%s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
// We check the debug name too since we put some effort into making that readable.
|
||||
// The name isn't load-bearing though.
|
||||
// CHECK: refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
|
||||
// CHECK: func @basic
|
||||
func @basic() -> tensor<3x4xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
|
||||
// Only one global is created.
|
||||
// CHECK: refback.global
|
||||
// CHECK-NOT: refback.global
|
||||
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<7.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
|
||||
// Two globals are created.
|
||||
// CHECK: refback.global
|
||||
// CHECK: refback.global
|
||||
// CHECK-NOT: refback.global
|
||||
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<8.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
// We don't convert non-tensor globals.
|
||||
// CHECK-NOT: refback.global
|
||||
func @non_tensor() {
|
||||
%0 = constant 7 : i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: }
|
|
@ -1,58 +0,0 @@
|
|||
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float>
|
||||
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32>
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__refbackrt_global_extents_g : !llvm.ptr<array<1 x i32>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__refbackrt_global_data_buffer_g : !llvm.ptr<array<3 x float>>
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm.struct<(i64, ptr<i8>)> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.mul %[[VAL_3]], %[[VAL_4]] : !llvm.i64
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_3]], %[[VAL_7]] : !llvm.i64
|
||||
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_2]] : !llvm.i64
|
||||
// CHECK: %[[VAL_10:.*]] = llvm.mul %[[VAL_9]], %[[VAL_5]] : !llvm.i64
|
||||
// CHECK: %[[VAL_11:.*]] = llvm.add %[[VAL_6]], %[[VAL_10]] : !llvm.i64
|
||||
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(false) : !llvm.i1
|
||||
// CHECK: %[[VAL_13:.*]] = llvm.call @malloc(%[[VAL_11]]) : (!llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: "llvm.intr.memcpy"(%[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_12]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
|
||||
// CHECK: %[[VAL_15:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][0] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: }
|
||||
refbackrt.global @g dense<7.000000e+00> : tensor<3xf32>
|
||||
func @calls_get_global() -> memref<*xf32> {
|
||||
%0 = refbackrt.get_global @g : memref<*xf32>
|
||||
return %0 : memref<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
|
||||
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float>
|
||||
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32>
|
||||
refbackrt.global @g dense<7.0> : tensor<f32>
|
|
@ -17,7 +17,6 @@
|
|||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_identity("identity")
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
|
||||
|
@ -114,7 +113,6 @@ func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
|
|||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
|
||||
|
@ -215,7 +213,6 @@ func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackr
|
|||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
|
||||
// CHECK-LABEL: llvm.func @calls_abort_if(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {
|
||||
|
|
|
@ -73,22 +73,6 @@ func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK: refbackrt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||
refback.global @g dense<7.0> : tensor<10xf32>
|
||||
// CHECK-LABEL: func @gets_global() -> !refbackrt.tensor
|
||||
func @gets_global() -> memref<10xf32> {
|
||||
// CHECK: %[[GMEMREF:.*]] = refbackrt.get_global @g : memref<*xf32>
|
||||
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
||||
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||
// CHECK: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||
// CHECK: return %[[RET]] : !refbackrt.tensor
|
||||
%0 = refback.get_global_memref @g : memref<10xf32>
|
||||
return %0 : memref<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test diagnostics.
|
||||
|
||||
// expected-error @+1 {{func not expressible with refbackrt ABI}}
|
||||
|
|
Loading…
Reference in New Issue