From e228aa4b11376f4e10376e2b367fbd41cf4dc572 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Fri, 10 Jul 2020 17:31:24 -0700 Subject: [PATCH] npcomprt: add support for constants - create tcp.global + tcp.get_global_memref - create npcomprt.global + npcomprt.get_global - LLVM lowering for new npcomprt ops - Runtime: - GlobalDescriptor struct emitted by LLVM lowering - implement __npcomp_compiler_rt_get_global Also, - cleanly isolate all runtime data structure definitions shared by the compiler and runtime into lib/runtime/CompilerDataStructures.h --- .../Dialect/Npcomprt/IR/NpcomprtBase.td | 2 +- .../npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h | 1 + .../npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td | 39 ++++ include/npcomp/Dialect/TCP/IR/TCPOps.h | 1 + include/npcomp/Dialect/TCP/IR/TCPOps.td | 27 +++ include/npcomp/E2E/E2E.h | 3 + include/npcomp/E2E/Passes.td | 9 + lib/Dialect/Npcomprt/IR/NpcomprtOps.cpp | 42 ++++ lib/Dialect/TCP/IR/TCPOps.cpp | 45 ++++ lib/E2E/E2E.cpp | 3 +- lib/E2E/LowerToHybridTensorMemRef.cpp | 86 +++++++ lib/E2E/LowerToLLVM.cpp | 209 ++++++++++++++++++ lib/E2E/LowerToNpcomprtABI.cpp | 41 ++++ lib/runtime/CompilerDataStructures.h | 60 +++++ lib/runtime/CompilerRuntime.cpp | 9 + lib/runtime/Runtime.cpp | 34 +-- test/Dialect/Npcomprt/invalid.mlir | 20 ++ test/Dialect/Npcomprt/ops.mlir | 6 + test/Dialect/TCP/invalid.mlir | 31 +++ test/Dialect/TCP/ops.mlir | 4 + .../lower-constant-tensors-to-memrefs.mlir | 61 +++++ test/E2E/lower-to-llvm-global.mlir | 32 +++ test/E2E/lower-to-npcomprt-abi.mlir | 17 ++ test/npcomp-run-mlir/constant-add-scalar.mlir | 12 + test/npcomp-run-mlir/constant-add.mlir | 12 + test/npcomp-run-mlir/constant.mlir | 10 + tools/bash_helpers.sh | 2 + 27 files changed, 784 insertions(+), 34 deletions(-) create mode 100644 lib/runtime/CompilerDataStructures.h create mode 100644 test/Dialect/TCP/invalid.mlir create mode 100644 test/E2E/lower-constant-tensors-to-memrefs.mlir create mode 100644 test/E2E/lower-to-llvm-global.mlir create mode 
100644 test/npcomp-run-mlir/constant-add-scalar.mlir create mode 100644 test/npcomp-run-mlir/constant-add.mlir create mode 100644 test/npcomp-run-mlir/constant.mlir diff --git a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td index 826170e75..514483a2f 100644 --- a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td +++ b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td @@ -27,7 +27,7 @@ def Npcomprt_Tensor : DialectType< Npcomprt_Dialect, CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">, - "buffer view">, + "npcomprt.tensor">, BuildableType< "$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> { let typeDescription = [{The runtime type that represents a buffer.}]; diff --git a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h index e1dd0acb9..ccbad2fbf 100644 --- a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h +++ b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h @@ -12,6 +12,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" namespace mlir { namespace NPCOMP { diff --git a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td index 3aa13cc7d..880fd742f 100644 --- a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td +++ b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td @@ -10,6 +10,7 @@ #define NPCOMPRT_OPS include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td" +include "mlir/IR/SymbolInterfaces.td" class Npcomprt_Op traits = []> : Op { @@ -57,6 +58,44 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> { let assemblyFormat = "$pred attr-dict"; } +def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> { + let summary = "Represents a global variable"; + let description = [{ + Represents a global variable. 
+ + Currently, only constant tensors are supported, and they are not + considered to be exported. + }]; + let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value); + let results = (outs); + + let printer = [{ return ::print$cppClass(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> { + let summary = "Obtain a rank-erased memref pointing at the given global"; + let description = [{ + Obtain a rank-erased memref pointing at the given global. + + TODO: As we define the runtime layer better, we should have fewer + entry points that return memrefs, or at least have a clearer separation + between the "memref world" and the "npcomprt world". + Something like forming IREE dispatch regions seems to be the missing thing: + - Everything inside the dispatch regions gets things marshaled from the + runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way. + - Everything outside the dispatch regions purely uses the runtime + (flow/hal/npcomprt) data structures. + Globals should be one of the things that are purely runtime data structures, + rather than using memrefs. For now, using memrefs is simpler though. 
+ }]; + let arguments = (ins FlatSymbolRefAttr:$global); + let results = (outs AnyUnrankedMemRef:$memref); + let assemblyFormat = "$global attr-dict `:` type($memref)"; + let verifier = "return ::verify$cppClass(*this);"; + // TODO: verify exists and shape is compatible +} + def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [ SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp"> ]> { diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.h b/include/npcomp/Dialect/TCP/IR/TCPOps.h index 86a1eba28..654b0eee1 100644 --- a/include/npcomp/Dialect/TCP/IR/TCPOps.h +++ b/include/npcomp/Dialect/TCP/IR/TCPOps.h @@ -13,6 +13,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.td b/include/npcomp/Dialect/TCP/IR/TCPOps.td index 5fcf4e219..da6046ac5 100644 --- a/include/npcomp/Dialect/TCP/IR/TCPOps.td +++ b/include/npcomp/Dialect/TCP/IR/TCPOps.td @@ -13,6 +13,7 @@ include "npcomp/Dialect/TCP/IR/TCPBase.td" include "mlir/Dialect/Shape/IR/ShapeBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/SymbolInterfaces.td" class TCP_Op traits = []> : Op { @@ -58,6 +59,32 @@ Allocates a memref of the given shape. let assemblyFormat = "$shape attr-dict `:` type($memref)"; } +def TCP_GlobalOp : TCP_Op<"global", [Symbol]> { + let summary = "Represents a global variable"; + let description = [{ + Represents a global variable. + + Currently, only constant tensors are supported, and they are not + considered to be exported. 
+ }]; + let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value); + let results = (outs); + + let printer = [{ return ::print$cppClass(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> { + let summary = "Obtain a memref pointing at the given global"; + let description = [{ + Obtain a memref pointing at the given global. + }]; + let arguments = (ins FlatSymbolRefAttr:$global); + let results = (outs AnyMemRef:$memref); + let assemblyFormat = "$global attr-dict `:` type($memref)"; + let verifier = "return ::verify$cppClass(*this);"; +} + // TODO: Change to a more principled error handling mechanism. // This op probably doesn't need to exist eventually. // This op is also not correctly modeled right now, since it itself doesn't diff --git a/include/npcomp/E2E/E2E.h b/include/npcomp/E2E/E2E.h index bc07d6955..dead7c08b 100644 --- a/include/npcomp/E2E/E2E.h +++ b/include/npcomp/E2E/E2E.h @@ -25,6 +25,9 @@ std::unique_ptr> createLowerBroadcastToToLoopsPass(); std::unique_ptr> createLowerLinalgOnTensorToLinalgOnMemrefPass(); +std::unique_ptr> +createLowerConstantTensorsToMemrefsPass(); + std::unique_ptr> createResolveShapeOfOpsPass(); std::unique_ptr> createResolveTensorLoadStoreOpsPass(); diff --git a/include/npcomp/E2E/Passes.td b/include/npcomp/E2E/Passes.td index 473fe744f..97e17d363 100644 --- a/include/npcomp/E2E/Passes.td +++ b/include/npcomp/E2E/Passes.td @@ -23,6 +23,15 @@ def LowerBroadcastToToLoops : let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()"; } +def LowerConstantTensorsToMemrefs : + Pass<"lower-constant-tensors-to-memrefs", "ModuleOp"> { + let summary = "Lower std.constant of tensor type to hybrid tensor/memref."; + let description = [{ + This has to be a module pass since it involves creating tcp.global ops. 
+ }]; + let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass()"; +} + def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> { let summary = "Resolve shape.shape_of ops to other shapes."; let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()"; diff --git a/lib/Dialect/Npcomprt/IR/NpcomprtOps.cpp b/lib/Dialect/Npcomprt/IR/NpcomprtOps.cpp index f1d80c39d..5dae7e5a6 100644 --- a/lib/Dialect/Npcomprt/IR/NpcomprtOps.cpp +++ b/lib/Dialect/Npcomprt/IR/NpcomprtOps.cpp @@ -10,11 +10,53 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeUtilities.h" #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h" using namespace mlir; using namespace mlir::NPCOMP::npcomprt; +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) { + p << "npcomprt.global "; + p.printSymbolName(op.sym_name()); + p << ' '; + p.printOptionalAttrDictWithKeyword(op.getAttrs(), + /*elidedAttrs=*/{"sym_name", "value"}); + p.printAttribute(op.valueAttr()); +} + +static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + Attribute valueAttr; + if (parser.parseAttribute(valueAttr, "value", result.attributes)) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalOp +//===----------------------------------------------------------------------===// + +static LogicalResult verifyGetGlobalOp(GetGlobalOp op) { + auto global = SymbolTable::lookupNearestSymbolFrom(op, op.global()); + if (!global) + return 
op.emitError() << "must reference a valid npcomprt.global"; + auto globalType = global.value().getType(); + auto resultType = op.getType().cast(); + if (globalType.getElementType() != resultType.getElementType()) + return op.emitError() << "inconsistent with element type of global"; + return success(); +} + //===----------------------------------------------------------------------===// // ModuleMetadataOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TCP/IR/TCPOps.cpp b/lib/Dialect/TCP/IR/TCPOps.cpp index a0289b230..417b81b8a 100644 --- a/lib/Dialect/TCP/IR/TCPOps.cpp +++ b/lib/Dialect/TCP/IR/TCPOps.cpp @@ -8,11 +8,56 @@ #include "npcomp/Dialect/TCP/IR/TCPOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/TypeUtilities.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::tcp; +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) { + p << "tcp.global "; + p.printSymbolName(op.sym_name()); + p << ' '; + p.printOptionalAttrDictWithKeyword(op.getAttrs(), + /*elidedAttrs=*/{"sym_name", "value"}); + p.printAttribute(op.valueAttr()); +} + +static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + Attribute valueAttr; + if (parser.parseAttribute(valueAttr, "value", result.attributes)) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalMemrefOp +//===----------------------------------------------------------------------===// + +static LogicalResult 
verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) { + auto global = SymbolTable::lookupNearestSymbolFrom(op, op.global()); + if (!global) + return op.emitError() << "must reference a valid symbol"; + auto globalType = global.value().getType(); + auto resultType = op.getType().cast(); + if (failed( + verifyCompatibleShape(globalType.getShape(), resultType.getShape()))) + return op.emitError() << "inconsistent with shape of global"; + if (globalType.getElementType() != resultType.getElementType()) + return op.emitError() << "inconsistent with element type of global"; + return success(); +} + //===----------------------------------------------------------------------===// // ShapeObserveErrorOp //===----------------------------------------------------------------------===// diff --git a/lib/E2E/E2E.cpp b/lib/E2E/E2E.cpp index 5ff554c83..803ce6d3a 100644 --- a/lib/E2E/E2E.cpp +++ b/lib/E2E/E2E.cpp @@ -359,7 +359,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline( pm.addPass(createResolveTensorLoadStoreOpsPass()); // At this point, the IR is in a form where there are no tensor ops - // (except tensor_store's of arguments and tensor_load's of returns). + // (except tensor_store's of arguments, tensor_load's of returns, and + // constants). // // This is a reasonable representation for doing buffer assignment. // TODO: Do buffer assignment here. diff --git a/lib/E2E/LowerToHybridTensorMemRef.cpp b/lib/E2E/LowerToHybridTensorMemRef.cpp index 924d9a4c2..84ac6e57a 100644 --- a/lib/E2E/LowerToHybridTensorMemRef.cpp +++ b/lib/E2E/LowerToHybridTensorMemRef.cpp @@ -296,6 +296,91 @@ mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() { return std::make_unique(); } +//===----------------------------------------------------------------------===// +// LowerConstantTensorsToMemrefs +//===----------------------------------------------------------------------===// + +namespace { +// This class creates global ops for all tensor-valued constants in the program. 
+// It creates them with pretty names and makes sure that duplicate globals +// aren't created. +class GlobalCreator { +public: + explicit GlobalCreator(ModuleOp module); + tcp::GlobalOp getGlobalFor(Attribute attr) { + assert(globals.find(attr) != globals.end() && "unknown constant attr"); + return globals[attr]; + } + +private: + DenseMap globals; +}; + +GlobalCreator::GlobalCreator(ModuleOp module) { + // Create a builder without an insertion point. We will insert using the + // symbol table to guarantee unique names. + OpBuilder globalBuilder(module.getContext()); + SymbolTable symbolTable(module); + module.walk([&](ConstantOp op) { + // We only want tensor constants for now. + auto type = op.getType().dyn_cast(); + if (!type) + return; + // If we already have a global for this constant value, no need to do + // anything else. + auto it = globals.find(op.getValue()); + if (it != globals.end()) + return; + + // Create a pretty name. + SmallString<64> buf; + llvm::raw_svector_ostream os(buf); + interleave(type.getShape(), os, "x"); + os << "x" << type.getElementType(); + + auto global = globalBuilder.create( + op.getLoc(), (Twine("__constant_") + os.str()).str(), + op.getValue().cast()); + symbolTable.insert(global); + // The symbol table inserts at the end of the module, but globals are a bit + // nicer if they are at the beginning. + global.getOperation()->moveBefore(&module.front()); + globals[op.getValue()] = global; + }); +} +} // namespace + +namespace { +class LowerConstantTensorsToMemrefs + : public LowerConstantTensorsToMemrefsBase { + void runOnOperation() { + auto module = getOperation(); + GlobalCreator globals(module); + + // With the global traversal factored into GlobalCreator, this could in + // principle be done with a pattern. 
+ module.walk([&](ConstantOp op) { + auto type = op.getType().dyn_cast(); + if (!type) + return; + auto global = globals.getGlobalFor(op.getValue()); + OpBuilder builder(op); + auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); + auto memref = builder.create( + op.getLoc(), memrefType, global.getName()); + Value tensor = builder.create(op.getLoc(), type, memref); + op.replaceAllUsesWith(tensor); + op.erase(); + }); + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass() { + return std::make_unique(); +} + void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) { // Lower to hybrid tensor/memref. // The invariant of "hybrid tensor/memref" is that the core computation @@ -305,6 +390,7 @@ void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) { // allocated with alloc_shape ops. // Thus, shape.shape_of ops on the original tensors in the program can be // resolved to the shapes in the alloc_memref calls. + pm.addPass(createLowerConstantTensorsToMemrefsPass()); pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass()); pm.addPass(createLowerBroadcastToToLoopsPass()); } diff --git a/lib/E2E/LowerToLLVM.cpp b/lib/E2E/LowerToLLVM.cpp index 33b1f9c34..9aa715d8b 100644 --- a/lib/E2E/LowerToLLVM.cpp +++ b/lib/E2E/LowerToLLVM.cpp @@ -22,6 +22,28 @@ using namespace mlir::NPCOMP; using mlir::LLVM::LLVMFuncOp; using mlir::LLVM::LLVMType; +//===----------------------------------------------------------------------===// +// Utilities. +//===----------------------------------------------------------------------===// + +// TODO: Move other descriptor types to here. + +// Get the LLVMType for npcomprt::GlobalDescriptor. 
+static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) { + return LLVMType::getStructTy( + // std::int32_t numExtents; + LLVMType::getIntNTy(llvmDialect, 32), + // std::int32_t *extents; + LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(), + // It is important that this struct member is a type-erased pointer + // so that this type is "context-free" and can be created in conversion + // patterns independently of the actual type of the data stored in the + // buffer. + // + // void *data; + LLVMType::getInt8PtrTy(llvmDialect)); +} + //===----------------------------------------------------------------------===// // Compiler runtime functions. //===----------------------------------------------------------------------===// @@ -72,6 +94,36 @@ public: }; } // namespace +namespace { +class GetGlobalOpCompilerRuntimeLowering + : public OpConversionPattern { +public: + GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc) + : OpConversionPattern(backingFunc.getContext()), + backingFunc(backingFunc) {} + LogicalResult + matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto llvmDialect = + rewriter.getContext()->getRegisteredDialect(); + // It would be nice if we could use the constructor here that takes just the + // global, but keeping track of the converted llvm.mlir.global op that gets + // created from the npcomprt.global while conversion is going on is a + // headache. + // + // Instead, we rely on the symbol name being the same and the result type + // always being the same. + auto globalAddr = rewriter.create( + op.getLoc(), getGlobalDescriptorTy(llvmDialect).getPointerTo(), + op.globalAttr()); + rewriter.replaceOpWithNewOp(op, backingFunc, + ValueRange({globalAddr})); + return success(); + } + LLVM::LLVMFuncOp backingFunc; +}; +} // namespace + // Create the LLVM runtime function backing the npcomprt op with name `name` // and requiring `type`. 
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type, @@ -140,8 +192,164 @@ static void populateCompilerRuntimePatterns(ModuleOp module, "from_memref", funcTy, builder, module.getLoc()); patterns.insert(fromMemrefFunc); } + + { + // Hardcoding f32 is fine here, since unranked memref descriptors have + // identical struct layout / ABI / contents regardless of the element type. + auto mlirFunctionType = builder.getFunctionType( + {getGlobalDescriptorTy(llvmDialect).getPointerTo()}, + {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)}); + LLVMType funcTy = convertFunctionType(mlirFunctionType); + LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl( + "get_global", funcTy, builder, module.getLoc()); + patterns.insert(backingFunc); + } } +//===----------------------------------------------------------------------===// +// Lowering for npcomprt.global +//===----------------------------------------------------------------------===// + +namespace { +class LowerNpcomprtGlobalOp : public OpConversionPattern { +public: + explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter) + : OpConversionPattern(&typeConverter.getContext()), + typeConverter(typeConverter) {} + LogicalResult + matchAndRewrite(npcomprt::GlobalOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto llvmDialect = typeConverter.getDialect(); + auto globalDescriptorTy = getGlobalDescriptorTy(llvmDialect); + + // Create the data buffer. + auto dataBuffer = createGlobalForDenseElementsAttr( + (Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(), + op.value().cast(), op, rewriter); + + // Create the extents buffer. 
+ auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>( + llvm::map_range(op.value().getType().cast().getShape(), + [](int64_t i) -> int32_t { return i; }))); + auto extentsBuffer = createGlobalForDenseElementsAttr( + (Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32, + op, rewriter); + + // Create the GlobalDescriptor. + auto globalDescriptorGlobal = rewriter.create( + op.getLoc(), globalDescriptorTy, /*isConstant=*/true, + LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute()); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.createBlock(&globalDescriptorGlobal.initializer()); + + // Create the body of the initializer. + Value globalDescriptor = + rewriter.create(op.getLoc(), globalDescriptorTy); + auto updateDescriptor = [&](Value value, + std::initializer_list position) { + globalDescriptor = rewriter.create( + op.getLoc(), globalDescriptor, value, + /*position=*/rewriter.getI32ArrayAttr(position)); + }; + updateDescriptor( + rewriter.create( + op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32), + rewriter.getI32IntegerAttr( + op.value().getType().cast().getRank())), + {0}); + + // The global is actually an array, so we need to get a bare i32* pointer + // type. We could do this with GEP but it would be more verbose. 
+ auto extentsBufferArrayAddress = + rewriter.create(op.getLoc(), extentsBuffer); + auto extentsBufferAddress = rewriter.create( + op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(), + extentsBufferArrayAddress); + updateDescriptor(extentsBufferAddress, {1}); + + auto dataBufferAddress = + rewriter.create(op.getLoc(), dataBuffer); + auto typeErasedDataBufferAddress = rewriter.create( + op.getLoc(), LLVMType::getInt8PtrTy(llvmDialect), dataBufferAddress); + updateDescriptor(typeErasedDataBufferAddress, {2}); + rewriter.create(op.getLoc(), globalDescriptor); + + rewriter.eraseOp(op); + return success(); + } + +private: + // TODO: It feels like MLIR core should have better utilities for this. + LLVM::GlobalOp createGlobalForDenseElementsAttr( + StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op, + ConversionPatternRewriter &rewriter) const { + auto type = elements.getType().cast(); + + // LLVM translation doesn't handle the case of zero-sized tensors, which can + // happen e.g. for the extents buffer of a rank-0 (i.e. scalar) tensor. + // + // We fake-up a size-1 DenseElementsAttr to use for creating the global. + // That takes up binary space (one element instead of zero), but that seems + // fine. + // + // TODO: LLVM translation in MLIR core should handle this case better. 
+ if (type.getNumElements() == 0) { + auto elementType = type.getElementType(); + Attribute singleElement; + if (elementType.isIntOrIndex()) + singleElement = rewriter.getIntegerAttr(elementType, 0); + else if (elementType.isa()) + singleElement = rewriter.getFloatAttr(elementType, 0); + assert(singleElement && + "could not fake up an element for a zero element tensor"); + type = RankedTensorType::get({1}, elementType); + elements = + DenseElementsAttr::get(type, ArrayRef(singleElement)); + } + + auto llvmType = getLLVMTypeForShapedType(type, op, rewriter); + return rewriter.create( + op.getLoc(), llvmType, + /*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements); + } + + LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op, + ConversionPatternRewriter &rewriter) const { + auto llvmType = + typeConverter.convertType(type.getElementType()).cast(); + + // MLIR->LLVM lowering for globals requires non-scalar data types. So use a + // dummy size-1 array for the scalar case. + // + // TODO: LLVM translation in MLIR core should handle this case better. + if (type.getRank() == 0) + return LLVMType::getArrayTy(llvmType, 1); + + if (!llvmType) { + rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "cannot convert element type " << type.getElementType() + << " to an LLVM type"; + }); + return nullptr; + } + + // Construct an LLVM nested array type for the tensor initializer. 
+ // tensor<f32> -> float + // tensor<10xf32> -> [10 x float] + // tensor<2x3xf32> -> [2 x [3 x float]] + assert(type.hasStaticShape()); + auto shape = type.getShape(); + while (!shape.empty()) { + llvmType = LLVMType::getArrayTy(llvmType, shape.back()); + shape = shape.drop_back(); + } + return llvmType; + } + LLVMTypeConverter &typeConverter; +}; +} // namespace + //===----------------------------------------------------------------------===// // Lowering for module metadata //===----------------------------------------------------------------------===// @@ -443,6 +651,7 @@ class LowerToLLVM : public LowerToLLVMBase { target.addLegalOp(); populateStdToLLVMConversionPatterns(converter, patterns); patterns.insert(context); + patterns.insert(converter); if (failed(applyFullConversion(module, target, patterns))) { return signalPassFailure(); diff --git a/lib/E2E/LowerToNpcomprtABI.cpp b/lib/E2E/LowerToNpcomprtABI.cpp index 8d4112314..5c6e05af4 100644 --- a/lib/E2E/LowerToNpcomprtABI.cpp +++ b/lib/E2E/LowerToNpcomprtABI.cpp @@ -139,6 +139,39 @@ public: }; } // namespace +namespace { +class LowerGlobalOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tcp::GlobalOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.sym_name(), + op.value()); + return success(); + } +}; +} // namespace + +namespace { +class LowerGetGlobalMemrefOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto abiMemrefType = UnrankedMemRefType::get( + op.getType().cast().getElementType(), /*memorySpace=*/0); + auto abiMemref = rewriter.create( + op.getLoc(), abiMemrefType, op.global()); + // Cast back to the original type. 
+ rewriter.replaceOpWithNewOp(op, abiMemref, op.getType()); + return success(); + } +}; +} // namespace + static LogicalResult doDialectConversion(ModuleOp module) { auto *context = module.getContext(); @@ -172,6 +205,14 @@ static LogicalResult doDialectConversion(ModuleOp module) { target.addLegalOp(); target.addLegalOp(); + patterns.insert(context); + target.addIllegalOp(); + target.addLegalOp(); + + patterns.insert(context); + target.addIllegalOp(); + target.addLegalOp(); + return applyPartialConversion(module, target, patterns); } diff --git a/lib/runtime/CompilerDataStructures.h b/lib/runtime/CompilerDataStructures.h new file mode 100644 index 000000000..5eb5a9a98 --- /dev/null +++ b/lib/runtime/CompilerDataStructures.h @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains data structures (which we typically call "descriptors") +// that are emitted by the compiler and must be kept in sync with the compiler +// code that creates them in LowerToLLVM.cpp. +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H +#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H + +#include + +namespace npcomprt { + +// All arguments are packed into this type-erased form for being invoked. See +// LowerToLLVM.cpp for more details. +typedef void ABIFunc(void **, void **); + +struct FuncDescriptor { + // The length of the function name. + std::int32_t nameLen; + // The name of the function, to allow lookup. + const char *name; + // This is a raw function pointer to the function's entry point as + // emitted by the compiler. 
+ ABIFunc *functionPtr; + // The number of inputs to the function. + std::int32_t numInputs; + // The number of outputs of the function. + std::int32_t numOutputs; + // TODO: Add arg/result descriptors and other metadata. + // With those descriptors we can do type and shape checking for each + // argument. +}; + +// The top-level entry point of the module metadata emitted by the +// compiler. Unlike all the other descriptors here, external code does handle +// this type (albeit through an opaque pointer). +struct ModuleDescriptor { + std::int32_t numFuncDescriptors; + FuncDescriptor *functionDescriptors; +}; + +// Static data representing a global variable (together with its shape). +struct GlobalDescriptor { + std::int32_t numExtents; + std::int32_t *extents; + void *data; +}; + +} // namespace npcomprt + +#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H \ No newline at end of file diff --git a/lib/runtime/CompilerRuntime.cpp b/lib/runtime/CompilerRuntime.cpp index 434a8b0c6..b0008f3bb 100644 --- a/lib/runtime/CompilerRuntime.cpp +++ b/lib/runtime/CompilerRuntime.cpp @@ -15,6 +15,7 @@ #include #include +#include "CompilerDataStructures.h" #include "npcomp/runtime/UserAPI.h" using namespace npcomprt; @@ -118,4 +119,12 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank, extents32Buf[i] = extents64[i]; return Tensor::createRaw(ArrayRef(extents32Buf.data(), rank), elementType, data); +} + +extern "C" UnrankedMemref +__npcomp_compiler_rt_get_global(GlobalDescriptor *global) { + auto *descriptor = MemrefDescriptor::create( + ArrayRef(global->extents, global->numExtents), + global->data); + return UnrankedMemref{global->numExtents, descriptor}; } \ No newline at end of file diff --git a/lib/runtime/Runtime.cpp b/lib/runtime/Runtime.cpp index f3bbc8032..453ef378c 100644 --- a/lib/runtime/Runtime.cpp +++ b/lib/runtime/Runtime.cpp @@ -13,6 +13,8 @@ #include #include +#include "CompilerDataStructures.h" + using namespace npcomprt; 
//===----------------------------------------------------------------------===// @@ -63,39 +65,7 @@ std::int32_t Tensor::getDataByteSize() const { //===----------------------------------------------------------------------===// // Module metadata descriptors. //===----------------------------------------------------------------------===// -// These descriptors are never created by runtime code. They are always -// embedded by the compiler as static data inside the module. -// -// Their definitions need to be kept in sync with the compiler code in -// LowerToLLVM.cpp -// All arguments are packed into this type-erased form for being invoked. See -// LowerToLLVM.cpp for more details. -typedef void ABIFunc(void **, void **); - -namespace { -struct FuncDescriptor { - // The length of the function name. - std::int32_t nameLen; - // The name of the function, to allow lookup. - const char *name; - // This is a raw function pointer to the function's entry point as - // emitted by the compiler. - ABIFunc *functionPtr; - std::int32_t numInputs; - std::int32_t numOutputs; - // TODO: Add arg/result descriptors and other metadata. - // With those descriptors. -}; -} // namespace - -// The top-level entry point of the module metadata emitted by the -// compiler. -struct npcomprt::ModuleDescriptor { - std::int32_t numFuncDescriptors; - // TODO: Update compiler code to emit this as a separate global. - FuncDescriptor *functionDescriptors; -}; //===----------------------------------------------------------------------===// // Module operations. 
diff --git a/test/Dialect/Npcomprt/invalid.mlir b/test/Dialect/Npcomprt/invalid.mlir index 1e3b480f8..c62716545 100644 --- a/test/Dialect/Npcomprt/invalid.mlir +++ b/test/Dialect/Npcomprt/invalid.mlir @@ -21,3 +21,23 @@ npcomprt.module_metadata { npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32} } func @f() { return } + +// ----- + +npcomprt.global @g dense<0.0> : tensor<2xf32> + +func @f() { + // expected-error @+1 {{must reference a valid npcomprt.global}} + npcomprt.get_global @nonexistent_symbol : memref<*xf32> + return +} + +// ----- + +npcomprt.global @g dense<0.0> : tensor<2xf32> + +func @f() { + // expected-error @+1 {{inconsistent with element type of global}} + npcomprt.get_global @g : memref<*xi8> + return +} diff --git a/test/Dialect/Npcomprt/ops.mlir b/test/Dialect/Npcomprt/ops.mlir index 1f78036c2..569bfc72e 100644 --- a/test/Dialect/Npcomprt/ops.mlir +++ b/test/Dialect/Npcomprt/ops.mlir @@ -12,3 +12,9 @@ func @f(%arg0: !npcomprt.tensor) { return } +// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32> +npcomprt.global @g dense<0.0> : tensor<10xf32> +func @uses_global() { + npcomprt.get_global @g : memref<*xf32> + return +} \ No newline at end of file diff --git a/test/Dialect/TCP/invalid.mlir b/test/Dialect/TCP/invalid.mlir new file mode 100644 index 000000000..81428791d --- /dev/null +++ b/test/Dialect/TCP/invalid.mlir @@ -0,0 +1,31 @@ +// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s + +// ----- + +tcp.global @g dense<0.0> : tensor<2xf32> + +func @f() { + // expected-error @+1 {{must reference a valid symbol}} + tcp.get_global_memref @nonexistent_symbol : memref<3xf32> + return +} + +// ----- + +tcp.global @g dense<0.0> : tensor<2xf32> + +func @f() { + // expected-error @+1 {{inconsistent with shape of global}} + tcp.get_global_memref @g : memref<3xf32> + return +} + +// ----- + +tcp.global @g dense<0.0> : tensor<2xf32> + +func @f() { + // expected-error @+1 {{inconsistent with element type 
of global}} + tcp.get_global_memref @g : memref<2xi8> + return +} \ No newline at end of file diff --git a/test/Dialect/TCP/ops.mlir b/test/Dialect/TCP/ops.mlir index 96f321a52..90b72457d 100644 --- a/test/Dialect/TCP/ops.mlir +++ b/test/Dialect/TCP/ops.mlir @@ -1,7 +1,11 @@ // RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail +// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32> +tcp.global @foo dense<0.0> : tensor<10xf32> + func @f(%arg0: tensor, %arg1: tensor, %arg2: i32) { // CHECK: tcp.add %0 = "tcp.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = tcp.get_global_memref @foo : memref<10xf32> return } diff --git a/test/E2E/lower-constant-tensors-to-memrefs.mlir b/test/E2E/lower-constant-tensors-to-memrefs.mlir new file mode 100644 index 000000000..85b6c6f9b --- /dev/null +++ b/test/E2E/lower-constant-tensors-to-memrefs.mlir @@ -0,0 +1,61 @@ +// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memrefs <%s | FileCheck %s + +// CHECK-LABEL: module { +// We check the debug name too since we put some effort into making that readable. +// The name isn't load-bearing though. +// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32> +// CHECK: func @basic +func @basic() -> tensor<3x4xf32> { + // CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32> + // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]] + %0 = constant dense<7.0> : tensor<3x4xf32> + // CHECK: return %[[TENSOR]] + return %0 : tensor<3x4xf32> +} + +// CHECK: } + +// ----- + +// CHECK-LABEL: module { + +// Only one global is created. +// CHECK: tcp.global +// CHECK-NOT: tcp.global +func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) { + %0 = constant dense<7.0> : tensor<3x4xf32> + %1 = constant dense<7.0> : tensor<3x4xf32> + // CHECK: return %[[TENSOR]] + return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32> +} + +// CHECK: } + +// ----- + +// CHECK-LABEL: module { + +// Two globals are created. 
+// CHECK: tcp.global +// CHECK: tcp.global +// CHECK-NOT: tcp.global +func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) { + %0 = constant dense<7.0> : tensor<3x4xf32> + %1 = constant dense<8.0> : tensor<3x4xf32> + // CHECK: return %[[TENSOR]] + return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32> +} + +// CHECK: } + +// ----- + +// CHECK-LABEL: module { +// We don't convert non-tensor globals. +// CHECK-NOT: tcp.global +func @non_tensor() { + %0 = constant 7 : i32 + return +} + +// CHECK: } \ No newline at end of file diff --git a/test/E2E/lower-to-llvm-global.mlir b/test/E2E/lower-to-llvm-global.mlir new file mode 100644 index 000000000..ada3d05b7 --- /dev/null +++ b/test/E2E/lower-to-llvm-global.mlir @@ -0,0 +1,32 @@ +// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail + +// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm<"[3 x float]"> +// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm<"[1 x i32]"> +// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm<"{ i32, i32*, i8* }"> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, i32*, i8* }"> +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, i32*, i8* }"> +// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm<"[1 x i32]*"> +// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x i32]*"> to !llvm<"i32*"> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, i32*, i8* }"> +// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm<"[3 x float]*"> +// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm<"[3 x float]*"> to !llvm<"i8*"> +// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : 
!llvm<"{ i32, i32*, i8* }"> +// CHECK: llvm.return %[[VAL_8]] : !llvm<"{ i32, i32*, i8* }"> +// CHECK: } +// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm<"{ i64, i8* }"> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm<"{ i32, i32*, i8* }*"> +// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm<"{ i32, i32*, i8* }*">) -> !llvm<"{ i64, i8* }"> +npcomprt.global @g dense<7.000000e+00> : tensor<3xf32> +func @calls_get_global() -> memref<*xf32> { + %0 = npcomprt.get_global @g : memref<*xf32> + return %0 : memref<*xf32> +} + +// ----- + +// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy. +// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm<"[1 x float]"> +// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm<"[1 x i32]"> +npcomprt.global @g dense<7.0> : tensor<f32> + diff --git a/test/E2E/lower-to-npcomprt-abi.mlir b/test/E2E/lower-to-npcomprt-abi.mlir index 6f6170352..4f468f419 100644 --- a/test/E2E/lower-to-npcomprt-abi.mlir +++ b/test/E2E/lower-to-npcomprt-abi.mlir @@ -37,6 +37,23 @@ func @basic(%arg0: tensor) -> tensor { // ----- + +// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32> +tcp.global @g dense<7.0> : tensor<10xf32> +// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor +func @gets_global() -> tensor<10xf32> { +// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32> +// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32> +// CHECK: %[[RETMEMREF:.*]] = memref_cast %[[ORIGMEMREF]] : memref<10xf32> to memref<*xf32> +// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[RETMEMREF]] : memref<*xf32> +// CHECK: return %[[RET]] : !npcomprt.tensor + %0 = tcp.get_global_memref @g : memref<10xf32> + %1 = tensor_load %0 : memref<10xf32> + return %1 : tensor<10xf32> +} + +// ----- + // expected-error @+1
{{func not expressible with npcomprt ABI}} func @unhandled_abi_type_on_public_func(%arg0: i32) { return diff --git a/test/npcomp-run-mlir/constant-add-scalar.mlir b/test/npcomp-run-mlir/constant-add-scalar.mlir new file mode 100644 index 000000000..e0cf9a772 --- /dev/null +++ b/test/npcomp-run-mlir/constant-add-scalar.mlir @@ -0,0 +1,12 @@ +// RUN: npcomp-run-mlir -input %s \ +// RUN: -invoke constant_add_scalar \ +// RUN: -arg-value="dense<3.0> : tensor<f32>" \ +// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \ +// RUN: | FileCheck %s + +// CHECK: output #0: dense<4.000000e+00> : tensor<f32> +func @constant_add_scalar(%arg0: tensor<f32>) -> tensor<f32> { + %0 = constant dense<1.0> : tensor<f32> + %1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32> + return %1 : tensor<f32> +} \ No newline at end of file diff --git a/test/npcomp-run-mlir/constant-add.mlir b/test/npcomp-run-mlir/constant-add.mlir new file mode 100644 index 000000000..3caf146b7 --- /dev/null +++ b/test/npcomp-run-mlir/constant-add.mlir @@ -0,0 +1,12 @@ +// RUN: npcomp-run-mlir -input %s \ +// RUN: -invoke constant_add \ +// RUN: -arg-value="dense<[3.0, 5.0]> : tensor<2xf32>" \ +// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \ +// RUN: | FileCheck %s + +// CHECK: output #0: dense<[4.000000e+00, 7.000000e+00]> : tensor<2xf32> +func @constant_add(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = constant dense<[1.0, 2.0]> : tensor<2xf32> + %1 = "tcf.add"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} \ No newline at end of file diff --git a/test/npcomp-run-mlir/constant.mlir b/test/npcomp-run-mlir/constant.mlir new file mode 100644 index 000000000..f17bdc7a8 --- /dev/null +++ b/test/npcomp-run-mlir/constant.mlir @@ -0,0 +1,10 @@ +// RUN: npcomp-run-mlir -input %s \ +// RUN: -invoke constant \ +// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \ +// RUN: | FileCheck %s + +// CHECK: output #0: dense<1.000000e+00> : tensor<f32> +func @constant() -> tensor<f32> { + %0 = constant dense<1.0> : tensor<f32> +
return %0 : tensor<f32> +} \ No newline at end of file diff --git a/tools/bash_helpers.sh b/tools/bash_helpers.sh index 01527794b..39c5ab56b 100644 --- a/tools/bash_helpers.sh +++ b/tools/bash_helpers.sh @@ -47,3 +47,5 @@ npctall() { # https://superuser.com/q/253068 export FIGNORE=$FIGNORE:nstall-mlir +export PYTHONPATH="$(realpath ${build_dir}/python):$(realpath ${build_dir}/python_native):$(realpath ${build_dir}/iree/bindings/python)" +