From 5227d52c26f73bde4ce81f247fea2a4e6c7f7da2 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 10 Nov 2020 15:14:02 -0800 Subject: [PATCH] [RefBackend] Use std.global_memref instead of homegrown thing This vastly simplifies our code, allowing deleting multiple ops, simplifying multiple passes, and removing a whole pass. Now `refback` dialect is down to one op (refback.alloc_memref, which simplifies allocations to just take a shape instead of individual extents). --- .../npcomp/Dialect/Refback/IR/RefbackOps.td | 26 --- .../Dialect/Refbackrt/IR/RefbackrtOps.td | 37 ---- include/npcomp/RefBackend/Passes.td | 10 - include/npcomp/RefBackend/RefBackend.h | 3 - lib/Dialect/Refback/IR/RefbackOps.cpp | 44 ---- lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp | 41 ---- lib/RefBackend/CMakeLists.txt | 1 - lib/RefBackend/LowerToLLVM.cpp | 202 ------------------ lib/RefBackend/LowerToRefbackrtABI.cpp | 37 ---- lib/RefBackend/RefBackend.cpp | 3 +- .../Runtime/CompilerDataStructures.h | 7 - lib/RefBackend/Runtime/CompilerRuntime.cpp | 8 - .../LowerConstantTensorsToMemref.cpp | 113 ---------- test/Dialect/Refback/invalid.mlir | 31 --- test/Dialect/Refback/ops.mlir | 11 +- test/Dialect/Refbackrt/invalid.mlir | 20 -- test/Dialect/Refbackrt/ops.mlir | 7 - .../lower-constant-tensors-to-memref.mlir | 61 ------ test/RefBackend/lower-to-llvm-global.mlir | 58 ----- test/RefBackend/lower-to-llvm.mlir | 3 - test/RefBackend/lower-to-refbackrt-abi.mlir | 16 -- 21 files changed, 5 insertions(+), 734 deletions(-) delete mode 100644 lib/RefBackend/TensorToMemref/LowerConstantTensorsToMemref.cpp delete mode 100644 test/Dialect/Refback/invalid.mlir delete mode 100644 test/RefBackend/lower-constant-tensors-to-memref.mlir delete mode 100644 test/RefBackend/lower-to-llvm-global.mlir diff --git a/include/npcomp/Dialect/Refback/IR/RefbackOps.td b/include/npcomp/Dialect/Refback/IR/RefbackOps.td index 709e957a3..6c3880f73 100644 --- a/include/npcomp/Dialect/Refback/IR/RefbackOps.td +++ b/include/npcomp/Dialect/Refback/IR/RefbackOps.td @@ -20,21 +20,6 @@ class Refback_Op traits = []> : Op { } -def Refback_GlobalOp : Refback_Op<"global", [Symbol]> { - let summary = "Represents a global variable"; - let description = [{ - Represents a global variable. - - Currently, only constant tensors are supported, and they are not - considered to be exported. - }]; - let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value); - let results = (outs); - - let printer = [{ return ::print$cppClass(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} - //===----------------------------------------------------------------------===// // Ops related to bufferization. //===----------------------------------------------------------------------===// @@ -52,15 +37,4 @@ def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> { let assemblyFormat = "$shape attr-dict `:` type($memref)"; } -def Refback_GetGlobalMemrefOp : Refback_Op<"get_global_memref"> { - let summary = "Obtain a memref pointing at the given global"; - let description = [{ - Obtain a memref pointing at the given global. - }]; - let arguments = (ins FlatSymbolRefAttr:$global); - let results = (outs AnyMemRef:$memref); - let assemblyFormat = "$global attr-dict `:` type($memref)"; - let verifier = "return ::verify$cppClass(*this);"; -} - #endif // REFBACK_OPS diff --git a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td index f504d24f3..732d6e059 100644 --- a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td +++ b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td @@ -46,43 +46,6 @@ def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> { let assemblyFormat = "$pred `,` $msg attr-dict"; } -def Refbackrt_GlobalOp : Refbackrt_Op<"global", [Symbol]> { - let summary = "Represents a global variable"; - let description = [{ - Represents a global variable. - - Currently, only constant tensors are supported, and they are not - considered to be exported. - }]; - let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value); - let results = (outs); - - let printer = [{ return ::print$cppClass(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} - -def Refbackrt_GetGlobalOp : Refbackrt_Op<"get_global"> { - let summary = "Obtain a rank-erased memref pointing at the given global"; - let description = [{ - Obtain a rank-erased memref pointing at the given global. - - TODO: As we define the runtime layer better, we should have fewer - entry points that return memrefs, or at least have a clearer separation - between the "memref world" and the "refbackrt world". - Something like forming IREE dispatch regions seems to be the missing thing: - - Everything inside the dispatch regions gets things marshaled from the - runtime (flow/hal/refbackrt) layer to/from memrefs in a clear way. - - Everything outside the dispatch regions purely uses the runtime - (flow/hal/refbackrt) data structures. - Globals should be one of the things that are purely runtime data structures, - rather than using memrefs. For now, using memrefs is simpler though. - }]; - let arguments = (ins FlatSymbolRefAttr:$global); - let results = (outs AnyUnrankedMemRef:$memref); - let assemblyFormat = "$global attr-dict `:` type($memref)"; - let verifier = "return ::verify$cppClass(*this);"; -} - def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [ SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp"> ]> { diff --git a/include/npcomp/RefBackend/Passes.td b/include/npcomp/RefBackend/Passes.td index 993bd7dfd..0b1bdc17f 100644 --- a/include/npcomp/RefBackend/Passes.td +++ b/include/npcomp/RefBackend/Passes.td @@ -11,15 +11,6 @@ include "mlir/Pass/PassBase.td" -def LowerConstantTensorsToMemref : - Pass<"lower-constant-tensors-to-memref", "ModuleOp"> { - let summary = "Lower std.constant of tensor type to memref"; - let description = [{ - This must be a module pass since it involves creating refback.global ops. - }]; - let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()"; -} - def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> { let summary = "Lower constructs requiring runtime support to `refbackrt`"; let description = [{ @@ -30,7 +21,6 @@ def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> { The constructs requiring runtime support are: - function signatures / module metadata - - globals - error handling }]; let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()"; diff --git a/include/npcomp/RefBackend/RefBackend.h b/include/npcomp/RefBackend/RefBackend.h index ebeaeca13..913301767 100644 --- a/include/npcomp/RefBackend/RefBackend.h +++ b/include/npcomp/RefBackend/RefBackend.h @@ -23,9 +23,6 @@ void registerRefBackendPasses(); // // Pass summaries are in Passes.td. -std::unique_ptr> -createLowerConstantTensorsToMemrefPass(); - std::unique_ptr> createLowerStructuralToMemrefPass(); std::unique_ptr> createLowerToRefbackrtABIPass(); diff --git a/lib/Dialect/Refback/IR/RefbackOps.cpp b/lib/Dialect/Refback/IR/RefbackOps.cpp index c99319b5a..1ca324dee 100644 --- a/lib/Dialect/Refback/IR/RefbackOps.cpp +++ b/lib/Dialect/Refback/IR/RefbackOps.cpp @@ -15,49 +15,5 @@ using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::refback; -//===----------------------------------------------------------------------===// -// GlobalOp -//===----------------------------------------------------------------------===// - -static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) { - p << "refback.global "; - p.printSymbolName(op.sym_name()); - p << ' '; - p.printOptionalAttrDictWithKeyword(op.getAttrs(), - /*elidedAttrs=*/{"sym_name", "value"}); - p.printAttribute(op.valueAttr()); -} - -static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { - StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) - return failure(); - Attribute valueAttr; - if (parser.parseAttribute(valueAttr, "value", result.attributes)) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// GetGlobalMemrefOp -//===----------------------------------------------------------------------===// - -static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) { - auto global = SymbolTable::lookupNearestSymbolFrom(op, op.global()); - if (!global) - return op.emitError() << "must reference a valid symbol"; - auto globalType = global.value().getType(); - auto resultType = op.getType().cast(); - if (failed( - verifyCompatibleShape(globalType.getShape(), resultType.getShape()))) - return op.emitError() << "inconsistent with shape of global"; - if (globalType.getElementType() != resultType.getElementType()) - return op.emitError() << "inconsistent with element type of global"; - return success(); -} - #define GET_OP_CLASSES #include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc" diff --git a/lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp b/lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp index bccede993..1b8f2bee8 100644 --- a/lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp +++ b/lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp @@ -16,47 +16,6 @@ using namespace mlir; using namespace mlir::NPCOMP::refbackrt; -//===----------------------------------------------------------------------===// -// GlobalOp -//===----------------------------------------------------------------------===// - -static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) { - p << "refbackrt.global "; - p.printSymbolName(op.sym_name()); - p << ' '; - p.printOptionalAttrDictWithKeyword(op.getAttrs(), - /*elidedAttrs=*/{"sym_name", "value"}); - p.printAttribute(op.valueAttr()); -} - -static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { - StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) - return failure(); - Attribute valueAttr; - if (parser.parseAttribute(valueAttr, "value", result.attributes)) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// GetGlobalOp -//===----------------------------------------------------------------------===// - -static LogicalResult verifyGetGlobalOp(GetGlobalOp op) { - auto global = SymbolTable::lookupNearestSymbolFrom(op, op.global()); - if (!global) - return op.emitError() << "must reference a valid refbackrt.global"; - auto globalType = global.value().getType(); - auto resultType = op.getType().cast(); - if (globalType.getElementType() != resultType.getElementType()) - return op.emitError() << "inconsistent with element type of global"; - return success(); -} - //===----------------------------------------------------------------------===// // ModuleMetadataOp //===----------------------------------------------------------------------===// diff --git a/lib/RefBackend/CMakeLists.txt b/lib/RefBackend/CMakeLists.txt index 1b5c63255..5db40364d 100644 --- a/lib/RefBackend/CMakeLists.txt +++ b/lib/RefBackend/CMakeLists.txt @@ -5,7 +5,6 @@ add_npcomp_library(NPCOMPRefBackend RefBackend.cpp LowerToLLVM.cpp LowerToRefbackrtABI.cpp - TensorToMemref/LowerConstantTensorsToMemref.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SRC_DIR}/include/npcomp/RefBackend diff --git a/lib/RefBackend/LowerToLLVM.cpp b/lib/RefBackend/LowerToLLVM.cpp index 5028d71dd..074439f3b 100644 --- a/lib/RefBackend/LowerToLLVM.cpp +++ b/lib/RefBackend/LowerToLLVM.cpp @@ -56,22 +56,6 @@ static LLVMType getModuleDescriptorTy(MLIRContext *context) { }); } -// Get the LLVMType for refbackrt::GlobalDescriptor. -static LLVMType getGlobalDescriptorTy(MLIRContext *context) { - return LLVMType::getStructTy( - // std::int32_t numExtents; - LLVMType::getIntNTy(context, 32), - // std::int32_t *extents; - LLVMType::getIntNTy(context, 32).getPointerTo(), - // It is important that this struct member is a type-erased pointer - // so that this type is "context-free" and can be created in conversion - // patterns independently of the actual type of the data stored in the - // buffer. - // - // void *data; - LLVMType::getInt8PtrTy(context)); -} - //===----------------------------------------------------------------------===// // Compiler runtime functions. //===----------------------------------------------------------------------===// @@ -122,35 +106,6 @@ public: }; } // namespace -namespace { -class GetGlobalOpCompilerRuntimeLowering - : public OpConversionPattern { -public: - GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc) - : OpConversionPattern(backingFunc.getContext()), - backingFunc(backingFunc) {} - LogicalResult - matchAndRewrite(refbackrt::GetGlobalOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // It would be nice if we could use the constructor here that takes just the - // global, but keeping track of the converted llvm.mlir.global op that gets - // created from the refbackrt.global while conversion is going on is a - // headache. - // - // Instead, we rely on the symbol name being the same and the result type - // always being the same. - auto globalAddr = rewriter.create( - op.getLoc(), - getGlobalDescriptorTy(rewriter.getContext()).getPointerTo(), - op.globalAttr()); - rewriter.replaceOpWithNewOp(op, backingFunc, - ValueRange({globalAddr})); - return success(); - } - LLVM::LLVMFuncOp backingFunc; -}; -} // namespace - static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg, OpBuilder &builder, Location loc) { // TODO: Deduplicate strings. @@ -262,164 +217,8 @@ static void populateCompilerRuntimePatterns(ModuleOp module, "from_memref", funcTy, builder, module.getLoc()); patterns.insert(fromMemrefFunc); } - - { - // Hardcoding f32 is fine here, since unranked memref descriptors have - // identical struct layout / ABI / contents regardless of the element type. - auto mlirFunctionType = builder.getFunctionType( - {getGlobalDescriptorTy(context).getPointerTo()}, - {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)}); - LLVMType funcTy = convertFunctionType(mlirFunctionType); - LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl( - "get_global", funcTy, builder, module.getLoc()); - patterns.insert(backingFunc); - } } -//===----------------------------------------------------------------------===// -// Lowering for refbackrt.global -//===----------------------------------------------------------------------===// - -namespace { -class LowerRefbackrtGlobalOp : public OpConversionPattern { -public: - explicit LowerRefbackrtGlobalOp(LLVMTypeConverter &typeConverter) - : OpConversionPattern(&typeConverter.getContext()), - typeConverter(typeConverter) {} - LogicalResult - matchAndRewrite(refbackrt::GlobalOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto *context = rewriter.getContext(); - auto globalDescriptorTy = getGlobalDescriptorTy(context); - - // Create the data buffer. - auto dataBuffer = createGlobalForDenseElementsAttr( - (Twine("__refbackrt_global_data_buffer_") + op.sym_name()).str(), - op.value().cast(), op, rewriter); - - // Create the extents buffer. - auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>( - llvm::map_range(op.value().getType().cast().getShape(), - [](int64_t i) -> int32_t { return i; }))); - auto extentsBuffer = createGlobalForDenseElementsAttr( - (Twine("__refbackrt_global_extents_") + op.sym_name()).str(), - extentsI32, op, rewriter); - - // Create the GlobalDescriptor. - auto globalDescriptorGlobal = rewriter.create( - op.getLoc(), globalDescriptorTy, /*isConstant=*/true, - LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&globalDescriptorGlobal.initializer()); - - // Create the body of the initializer. - Value globalDescriptor = - rewriter.create(op.getLoc(), globalDescriptorTy); - auto updateDescriptor = [&](Value value, - std::initializer_list position) { - globalDescriptor = rewriter.create( - op.getLoc(), globalDescriptor, value, - /*position=*/rewriter.getI32ArrayAttr(position)); - }; - updateDescriptor( - rewriter.create( - op.getLoc(), LLVMType::getIntNTy(context, 32), - rewriter.getI32IntegerAttr( - op.value().getType().cast().getRank())), - {0}); - - // The global is actually an array, so we need to get a bare i32* pointer - // type. We could do this with GEP but it would be more verbose. - auto extentsBufferArrayAddress = - rewriter.create(op.getLoc(), extentsBuffer); - auto extentsBufferAddress = rewriter.create( - op.getLoc(), LLVMType::getIntNTy(context, 32).getPointerTo(), - extentsBufferArrayAddress); - updateDescriptor(extentsBufferAddress, {1}); - - auto dataBufferAddress = - rewriter.create(op.getLoc(), dataBuffer); - auto typeErasedDataBufferAddress = rewriter.create( - op.getLoc(), LLVMType::getInt8PtrTy(context), dataBufferAddress); - updateDescriptor(typeErasedDataBufferAddress, {2}); - rewriter.create(op.getLoc(), globalDescriptor); - - rewriter.eraseOp(op); - return success(); - } - -private: - // TODO: It feels like MLIR core should have better utilities for this. - LLVM::GlobalOp createGlobalForDenseElementsAttr( - StringRef symbolName, DenseElementsAttr elements, refbackrt::GlobalOp op, - ConversionPatternRewriter &rewriter) const { - auto type = elements.getType().cast(); - - // LLVM translation doesn't handle the case of zero-sized tensors, which can - // happen e.g. for the number of extents of a rank-0 (i.e. scalar). - // - // We fake-up a size-1 DenseElementsAttr to use for creating the global. - // That takes up binary space (one element instead of zero), but that seems - // fine. - // - // TODO: LLVM translation in MLIR core should handle this case better. - if (type.getNumElements() == 0) { - auto elementType = type.getElementType(); - Attribute singleElement; - if (elementType.isIntOrIndex()) - singleElement = rewriter.getIntegerAttr(elementType, 0); - else if (elementType.isa()) - singleElement = rewriter.getFloatAttr(elementType, 0); - assert(singleElement && - "could not fake up an element for a zero element tensor"); - type = RankedTensorType::get({1}, elementType); - elements = - DenseElementsAttr::get(type, ArrayRef(singleElement)); - } - - auto llvmType = getLLVMTypeForShapedType(type, op, rewriter); - return rewriter.create( - op.getLoc(), llvmType, - /*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements); - } - - LLVMType getLLVMTypeForShapedType(ShapedType type, refbackrt::GlobalOp op, - ConversionPatternRewriter &rewriter) const { - auto llvmType = - typeConverter.convertType(type.getElementType()).cast(); - - // MLIR->LLVM lowering for globals requires non-scalar data types. So use a - // dummy size-1 array for the scalar case. - // - // TODO: LLVM translation in MLIR core should handle this case better. - if (type.getRank() == 0) - return LLVMType::getArrayTy(llvmType, 1); - - if (!llvmType) { - rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "cannot convert element type " << type.getElementType() - << " to an LLVM type"; - }); - return nullptr; - } - - // Construct an LLVM nested array type for the tensor initializer. - // tensor -> float - // tensor<10xf32> -> [10 x float] - // tensor<2x3xf32> -> [2 x [3 x float]] - assert(type.hasStaticShape()); - auto shape = type.getShape(); - while (!shape.empty()) { - llvmType = LLVMType::getArrayTy(llvmType, shape.back()); - shape = shape.drop_back(); - } - return llvmType; - } - LLVMTypeConverter &typeConverter; -}; -} // namespace - //===----------------------------------------------------------------------===// // Lowering for module metadata //===----------------------------------------------------------------------===// @@ -706,7 +505,6 @@ class LowerToLLVM : public LowerToLLVMBase { target.addLegalOp(); populateStdToLLVMConversionPatterns(converter, patterns); patterns.insert(context); - patterns.insert(converter); // TODO: Move these "std to std" legalizations to their own pass if we grow // lots of these patterns. diff --git a/lib/RefBackend/LowerToRefbackrtABI.cpp b/lib/RefBackend/LowerToRefbackrtABI.cpp index 7012afe32..4d28c4631 100644 --- a/lib/RefBackend/LowerToRefbackrtABI.cpp +++ b/lib/RefBackend/LowerToRefbackrtABI.cpp @@ -74,37 +74,6 @@ static LogicalResult createModuleMetadata(ModuleOp module) { // Dialect conversion. //===----------------------------------------------------------------------===// -namespace { -class LowerGlobalOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(refback::GlobalOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.sym_name(), - op.value()); - return success(); - } -}; -} // namespace - -namespace { -class LowerGetGlobalMemrefOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto abiMemref = rewriter.create( - op.getLoc(), getABIMemrefType(op.getType()), op.global()); - // Cast back to the original type. - rewriter.replaceOpWithNewOp(op, abiMemref, op.getType()); - return success(); - } -}; -} // namespace - namespace { class LowerAssertOp : public OpConversionPattern { public: @@ -216,12 +185,6 @@ static LogicalResult doDialectConversion(ModuleOp module) { target.addDynamicallyLegalOp( [&](ReturnOp op) { return typeConverter.isLegal(op); }); - patterns.insert(context); - target.addIllegalOp(); - - patterns.insert(context); - target.addIllegalOp(); - patterns.insert(context); target.addIllegalOp(); diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index f2e10dd2f..85cce80e2 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -228,8 +228,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // Bufferize the TCP dialect. pm.addNestedPass(createTCPBufferizePass()); - // Lower tensor-valued constants to refback.global. - pm.addPass(createLowerConstantTensorsToMemrefPass()); + pm.addPass(createTensorConstantBufferizePass()); // refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. // We need to resolve this to std.alloc which takes individual extents. pm.addNestedPass(createLowerAllocMemRefOpsPass()); diff --git a/lib/RefBackend/Runtime/CompilerDataStructures.h b/lib/RefBackend/Runtime/CompilerDataStructures.h index ae669f6bf..76b986608 100644 --- a/lib/RefBackend/Runtime/CompilerDataStructures.h +++ b/lib/RefBackend/Runtime/CompilerDataStructures.h @@ -48,13 +48,6 @@ struct ModuleDescriptor { FuncDescriptor *functionDescriptors; }; -// Static data representing a global variable (together with its shape). -struct GlobalDescriptor { - std::int32_t numExtents; - std::int32_t *extents; - void *data; -}; - } // namespace refbackrt #endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H diff --git a/lib/RefBackend/Runtime/CompilerRuntime.cpp b/lib/RefBackend/Runtime/CompilerRuntime.cpp index 1b15fa30c..3bfae89b0 100644 --- a/lib/RefBackend/Runtime/CompilerRuntime.cpp +++ b/lib/RefBackend/Runtime/CompilerRuntime.cpp @@ -114,11 +114,3 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank, return Tensor::createRaw(ArrayRef(extents32Buf.data(), rank), elementType, data); } - -extern "C" UnrankedMemref -__npcomp_compiler_rt_get_global(GlobalDescriptor *global) { - auto *descriptor = MemrefDescriptor::create( - ArrayRef(global->extents, global->numExtents), - global->data); - return UnrankedMemref{global->numExtents, descriptor}; -} diff --git a/lib/RefBackend/TensorToMemref/LowerConstantTensorsToMemref.cpp b/lib/RefBackend/TensorToMemref/LowerConstantTensorsToMemref.cpp deleted file mode 100644 index dc3b0077f..000000000 --- a/lib/RefBackend/TensorToMemref/LowerConstantTensorsToMemref.cpp +++ /dev/null @@ -1,113 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "../PassDetail.h" -#include "npcomp/RefBackend/RefBackend.h" - -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" -#include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/DialectConversion.h" -#include "npcomp/Dialect/Refback/IR/RefbackDialect.h" -#include "npcomp/Dialect/Refback/IR/RefbackOps.h" - -using namespace mlir; -using namespace mlir::NPCOMP; - -//===----------------------------------------------------------------------===// -// LowerConstantTensorsToMemref -//===----------------------------------------------------------------------===// - -namespace { -// This class creates global ops for all tensor-valued constants in the program. -// It creates them with pretty names and makes sure that duplicate globals -// aren't created. -class GlobalCreator { -public: - explicit GlobalCreator(ModuleOp module); - refback::GlobalOp getGlobalFor(Attribute attr) { - assert(globals.find(attr) != globals.end() && "unknown constant attr"); - return globals[attr]; - } - -private: - DenseMap globals; -}; - -GlobalCreator::GlobalCreator(ModuleOp module) { - // Create a builder without an insertion point. We will insert using the - // symbol table to guarantee unique names. - OpBuilder globalBuilder(module.getContext()); - SymbolTable symbolTable(module); - module.walk([&](ConstantOp op) { - // We only want tensor constants for now. - auto type = op.getType().dyn_cast(); - if (!type) - return; - // If we already have a global for this constant value, no need to do - // anything else. - auto it = globals.find(op.getValue()); - if (it != globals.end()) - return; - - // Create a pretty name. - SmallString<64> buf; - llvm::raw_svector_ostream os(buf); - interleave(type.getShape(), os, "x"); - os << "x" << type.getElementType(); - - auto global = globalBuilder.create( - op.getLoc(), (Twine("__constant_") + os.str()).str(), - op.getValue().cast()); - symbolTable.insert(global); - // The symbol table inserts at the end of the module, but globals are a bit - // nicer if they are at the beginning. - global.getOperation()->moveBefore(&module.front()); - globals[op.getValue()] = global; - }); -} -} // namespace - -namespace { -class LowerConstantTensorsToMemref - : public LowerConstantTensorsToMemrefBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - auto module = getOperation(); - GlobalCreator globals(module); - - // With the global traversal factored into GlobalCreator, this could in - // principle be done with a pattern. - module.walk([&](ConstantOp op) { - auto type = op.getType().dyn_cast(); - if (!type) - return; - auto global = globals.getGlobalFor(op.getValue()); - OpBuilder builder(op); - auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - auto memref = builder.create( - op.getLoc(), memrefType, global.getName()); - Value tensor = builder.create(op.getLoc(), type, memref); - op.replaceAllUsesWith(tensor); - op.erase(); - }); - } -}; -} // namespace - -std::unique_ptr> -mlir::NPCOMP::createLowerConstantTensorsToMemrefPass() { - return std::make_unique(); -} diff --git a/test/Dialect/Refback/invalid.mlir b/test/Dialect/Refback/invalid.mlir deleted file mode 100644 index d7813a285..000000000 --- a/test/Dialect/Refback/invalid.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s - -// ----- - -refback.global @g dense<0.0> : tensor<2xf32> - -func @f() { - // expected-error @+1 {{must reference a valid symbol}} - refback.get_global_memref @nonexistent_symbol : memref<3xf32> - return -} - -// ----- - -refback.global @g dense<0.0> : tensor<2xf32> - -func @f() { - // expected-error @+1 {{inconsistent with shape of global}} - refback.get_global_memref @g : memref<3xf32> - return -} - -// ----- - -refback.global @g dense<0.0> : tensor<2xf32> - -func @f() { - // expected-error @+1 {{inconsistent with element type of global}} - refback.get_global_memref @g : memref<2xi8> - return -} diff --git a/test/Dialect/Refback/ops.mlir b/test/Dialect/Refback/ops.mlir index aa3f66140..199eb40c8 100644 --- a/test/Dialect/Refback/ops.mlir +++ b/test/Dialect/Refback/ops.mlir @@ -1,11 +1,8 @@ // RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s -// CHECK-LABEL: refback.global @foo dense<0.0{{.*}}> : tensor<10xf32> -refback.global @foo dense<0.0> : tensor<10xf32> - -// CHECK-LABEL: func @global -func @global() { - // CHECK: refback.get_global_memref @foo : memref<10xf32> - %0 = refback.get_global_memref @foo : memref<10xf32> +// CHECK-LABEL: @alloc_memref +func @alloc_memref(%arg0: tensor) { + // CHECK: refback.alloc_memref + %0 = refback.alloc_memref %arg0 : memref return } diff --git a/test/Dialect/Refbackrt/invalid.mlir b/test/Dialect/Refbackrt/invalid.mlir index d33207df7..45609a826 100644 --- a/test/Dialect/Refbackrt/invalid.mlir +++ b/test/Dialect/Refbackrt/invalid.mlir @@ -21,23 +21,3 @@ refbackrt.module_metadata { refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32} } func @f() { return } - -// ----- - -refbackrt.global @g dense<0.0> : tensor<2xf32> - -func @f() { - // expected-error @+1 {{must reference a valid refbackrt.global}} - refbackrt.get_global @nonexistent_symbol : memref<*xf32> - return -} - -// ----- - -refbackrt.global @g dense<0.0> : tensor<2xf32> - -func @f() { - // expected-error @+1 {{inconsistent with element type of global}} - refbackrt.get_global @g : memref<*xi8> - return -} diff --git a/test/Dialect/Refbackrt/ops.mlir b/test/Dialect/Refbackrt/ops.mlir index 02ccbc20f..bf67058ca 100644 --- a/test/Dialect/Refbackrt/ops.mlir +++ b/test/Dialect/Refbackrt/ops.mlir @@ -11,10 +11,3 @@ refbackrt.module_metadata { func @f(%arg0: !refbackrt.tensor) { return } - -// CHECK-LABEL: refbackrt.global @g dense<0.0{{.*}}> : tensor<10xf32> -refbackrt.global @g dense<0.0> : tensor<10xf32> -func @uses_global() { - refbackrt.get_global @g : memref<*xf32> - return -} diff --git a/test/RefBackend/lower-constant-tensors-to-memref.mlir b/test/RefBackend/lower-constant-tensors-to-memref.mlir deleted file mode 100644 index 5a5d7c01b..000000000 --- a/test/RefBackend/lower-constant-tensors-to-memref.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memref <%s | FileCheck %s - -// CHECK-LABEL: module { -// We check the debug name too since we put some effort into making that readable. -// The name isn't load-bearing though. -// CHECK: refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32> -// CHECK: func @basic -func @basic() -> tensor<3x4xf32> { - // CHECK: %[[MEMREF:.*]] = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32> - // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]] - %0 = constant dense<7.0> : tensor<3x4xf32> - // CHECK: return %[[TENSOR]] - return %0 : tensor<3x4xf32> -} - -// CHECK: } - -// ----- - -// CHECK-LABEL: module { - -// Only one global is created. -// CHECK: refback.global -// CHECK-NOT: refback.global -func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) { - %0 = constant dense<7.0> : tensor<3x4xf32> - %1 = constant dense<7.0> : tensor<3x4xf32> - // CHECK: return %[[TENSOR]] - return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32> -} - -// CHECK: } - -// ----- - -// CHECK-LABEL: module { - -// Two globals are created. -// CHECK: refback.global -// CHECK: refback.global -// CHECK-NOT: refback.global -func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) { - %0 = constant dense<7.0> : tensor<3x4xf32> - %1 = constant dense<8.0> : tensor<3x4xf32> - // CHECK: return %[[TENSOR]] - return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32> -} - -// CHECK: } - -// ----- - -// CHECK-LABEL: module { -// We don't convert non-tensor globals. -// CHECK-NOT: refback.global -func @non_tensor() { - %0 = constant 7 : i32 - return -} - -// CHECK: } diff --git a/test/RefBackend/lower-to-llvm-global.mlir b/test/RefBackend/lower-to-llvm-global.mlir deleted file mode 100644 index a8619c373..000000000 --- a/test/RefBackend/lower-to-llvm-global.mlir +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail - -// CHECK-LABEL: llvm.func @malloc(!llvm.i64) -> !llvm.ptr -// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr) -// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr) -> !llvm.struct<(i64, ptr)> -// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr) -> !llvm.ptr -// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr, ptr)>>) -> !llvm.struct<(i64, ptr)> -// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float> -// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32> - -// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr, ptr)> { -// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr, ptr)> -// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 -// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr, ptr)> -// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__refbackrt_global_extents_g : !llvm.ptr> -// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr> to !llvm.ptr -// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr, ptr)> -// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__refbackrt_global_data_buffer_g : !llvm.ptr> -// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr> to !llvm.ptr -// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr, ptr)> -// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr, ptr)> -// CHECK: } - -// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm.struct<(i64, ptr)> { -// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm.ptr, ptr)>> -// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm.ptr, ptr)>>) -> !llvm.struct<(i64, ptr)> -// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 -// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 -// CHECK: %[[VAL_6:.*]] = llvm.mul %[[VAL_3]], %[[VAL_4]] : !llvm.i64 -// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr)> -// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_3]], %[[VAL_7]] : !llvm.i64 -// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_2]] : !llvm.i64 -// CHECK: %[[VAL_10:.*]] = llvm.mul %[[VAL_9]], %[[VAL_5]] : !llvm.i64 -// CHECK: %[[VAL_11:.*]] = llvm.add %[[VAL_6]], %[[VAL_10]] : !llvm.i64 -// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(false) : !llvm.i1 -// CHECK: %[[VAL_13:.*]] = llvm.call @malloc(%[[VAL_11]]) : (!llvm.i64) -> !llvm.ptr -// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(i64, ptr)> -// CHECK: "llvm.intr.memcpy"(%[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_12]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () -// CHECK: %[[VAL_15:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr)> -// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][0] : !llvm.struct<(i64, ptr)> -// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr)> -// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr)> -// CHECK: } -refbackrt.global @g dense<7.000000e+00> : tensor<3xf32> -func @calls_get_global() -> memref<*xf32> { - %0 = refbackrt.get_global @g : memref<*xf32> - return %0 : memref<*xf32> -} - -// ----- - -// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy. -// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor) : !llvm.array<1 x float> -// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32> -refbackrt.global @g dense<7.0> : tensor diff --git a/test/RefBackend/lower-to-llvm.mlir b/test/RefBackend/lower-to-llvm.mlir index 8c7cfca2a..cdde32cec 100644 --- a/test/RefBackend/lower-to-llvm.mlir +++ b/test/RefBackend/lower-to-llvm.mlir @@ -17,7 +17,6 @@ // CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr) // CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr) -> !llvm.struct<(i64, ptr)> // CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr) -> !llvm.ptr -// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr, ptr)>>) -> !llvm.struct<(i64, ptr)> // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_identity("identity") // CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<1 x struct<(i32, ptr, ptr, i32, i32)>> { @@ -114,7 +113,6 @@ func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor { // CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr) // CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr) -> !llvm.struct<(i64, ptr)> // CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr) -> !llvm.ptr -// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr, ptr)>>) -> !llvm.struct<(i64, ptr)> // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0") // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1") // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2") @@ -215,7 +213,6 @@ func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackr // CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr) // CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr) -> !llvm.struct<(i64, ptr)> // CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr) -> !llvm.ptr -// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr, ptr)>>) -> !llvm.struct<(i64, ptr)> // CHECK-LABEL: llvm.func @calls_abort_if( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) { diff --git a/test/RefBackend/lower-to-refbackrt-abi.mlir b/test/RefBackend/lower-to-refbackrt-abi.mlir index 13956968b..81071da7e 100644 --- a/test/RefBackend/lower-to-refbackrt-abi.mlir +++ b/test/RefBackend/lower-to-refbackrt-abi.mlir @@ -73,22 +73,6 @@ func @multiple_blocks(%arg0: memref) -> memref { // ----- - -// CHECK: refbackrt.global @g dense<7.000000e+00> : tensor<10xf32> -refback.global @g dense<7.0> : tensor<10xf32> -// CHECK-LABEL: func @gets_global() -> !refbackrt.tensor -func @gets_global() -> memref<10xf32> { -// CHECK: %[[GMEMREF:.*]] = refbackrt.get_global @g : memref<*xf32> -// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32> -// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32> -// CHECK: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32> -// CHECK: return %[[RET]] : !refbackrt.tensor - %0 = refback.get_global_memref @g : memref<10xf32> - return %0 : memref<10xf32> -} - -// ----- - // Test diagnostics. // expected-error @+1 {{func not expressible with refbackrt ABI}}