From 1d3dbd9d5c51b761027dc862b6e8d3167a8ce8df Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 20 May 2020 18:48:53 -0700 Subject: [PATCH] Lower to LLVM dialect. With this commit, we finish conversion to LLVM dialect, and should be ready for subsequent commits to convert to an LLVM module and let LLVM codegen to native machine code. This required a custom "lower to LLVM" pass to support lowering tcp.abort_if to a runtime call. In the future, this pass will grow to do type conversions for our own runtime types as we add those. --- include/npcomp/E2E/E2E.h | 2 + include/npcomp/E2E/Passes.td | 5 +++ lib/E2E/CMakeLists.txt | 3 ++ lib/E2E/E2E.cpp | 24 +++++++++-- lib/E2E/LowerToLLVM.cpp | 79 ++++++++++++++++++++++++++++++++++++ test/E2E/lower-to-llvm.mlir | 15 +++++++ 6 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 lib/E2E/LowerToLLVM.cpp create mode 100644 test/E2E/lower-to-llvm.mlir diff --git a/include/npcomp/E2E/E2E.h b/include/npcomp/E2E/E2E.h index cad944d60..2f854ead9 100644 --- a/include/npcomp/E2E/E2E.h +++ b/include/npcomp/E2E/E2E.h @@ -38,6 +38,8 @@ std::unique_ptr> createLowerToMemRefABIPass(); std::unique_ptr> createLowerAllocMemRefOpsPass(); +std::unique_ptr> createLowerToLLVMPass(); + void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm); // The main pipeline that encapsulates the full E2E lowering. diff --git a/include/npcomp/E2E/Passes.td b/include/npcomp/E2E/Passes.td index 963cb03c1..830548479 100644 --- a/include/npcomp/E2E/Passes.td +++ b/include/npcomp/E2E/Passes.td @@ -53,4 +53,9 @@ def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> { let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()"; } +def LowerToLLVM : Pass<"e2e-lower-to-llvm", "ModuleOp"> { + let summary = "Lower everything to LLVM"; + let constructor = "mlir::NPCOMP::createLowerToLLVMPass();"; +} + #endif // NPCOMP_E2E_PASSES diff --git a/lib/E2E/CMakeLists.txt b/lib/E2E/CMakeLists.txt index a93ac37ad..6140afdbc 100644 --- a/lib/E2E/CMakeLists.txt +++ b/lib/E2E/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(NPCOMPE2E E2E.cpp LowerRankedShapes.cpp LowerToHybridTensorMemRef.cpp + LowerToLLVM.cpp LowerToMemRefABI.cpp ADDITIONAL_HEADER_DIRS @@ -18,4 +19,6 @@ add_mlir_library(NPCOMPE2E MLIRIR MLIRLinalgOps MLIRStandardOps + MLIRStandardToLLVM + MLIRLoopToStandard ) diff --git a/lib/E2E/E2E.cpp b/lib/E2E/E2E.cpp index 102024764..c769e9baf 100644 --- a/lib/E2E/E2E.cpp +++ b/lib/E2E/E2E.cpp @@ -42,6 +42,7 @@ #include "npcomp/E2E/E2E.h" #include "PassDetail.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -379,7 +380,7 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) { pm.addPass(createCSEPass()); // -------------------------------------------------------------------------- - // Lowering down to LLVM + // Preparation for converting to an LLVM module. // -------------------------------------------------------------------------- // Now, we begin the process of lowering to LLVM's level of abstraction // (after which LLVM will take over lowering to machine code). @@ -423,6 +424,23 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) { // pass that checks no !shape.shape types left. pm.addPass(createLowerRankedShapesPass()); - // TODO: - // Convert all of it to LLVM? + + // Run a final canonicalization pass to delete dead + // `tcp.shape_from_extents` ops. + // This is needed for correctness, since we can't currently lower that op + // to LLVM, since we don't have a runtime representation of `!shape.shape`. + // TODO: Change LowerRankedShapes to delete these ops itself. + pm.addPass(createCanonicalizerPass()); + + // -------------------------------------------------------------------------- + // Final conversion to an LLVM module. + // -------------------------------------------------------------------------- + + // Convert scf to std control flow in preparation for going to LLVM. + pm.addPass(createLowerToCFGPass()); + + // Finally, convert to LLVM dialect using our custom LowerToLLVM pass + // which reuses the upstream patterns and gives us a place to add our own + // patterns for any custom ops and types we wish to lower. + pm.addPass(createLowerToLLVMPass()); } diff --git a/lib/E2E/LowerToLLVM.cpp b/lib/E2E/LowerToLLVM.cpp new file mode 100644 index 000000000..ff2671c7e --- /dev/null +++ b/lib/E2E/LowerToLLVM.cpp @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/E2E/E2E.h" +#include "PassDetail.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "npcomp/Dialect/TCP/IR/TCPOps.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using mlir::LLVM::LLVMType; + +namespace { +class LowerAbortIf : public OpConversionPattern { +public: + LowerAbortIf(LLVM::LLVMFuncOp abortIfFunc) + : OpConversionPattern(abortIfFunc.getContext()), + abortIfFunc(abortIfFunc) {} + LogicalResult + matchAndRewrite(tcp::AbortIfOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + tcp::AbortIfOp::OperandAdaptor adaptor(operands); + rewriter.replaceOpWithNewOp(op, abortIfFunc, adaptor.pred()); + return success(); + } + LLVM::LLVMFuncOp abortIfFunc; +}; +} // namespace + +// Create the LLVM function declaration for our runtime function +// that backs the tcp.abort_if op. +LLVM::LLVMFuncOp createAbortIfFuncDecl(ModuleOp module) { + auto *llvmDialect = + module.getContext()->getRegisteredDialect(); + auto abortIfFuncTy = LLVMType::getFunctionTy( + LLVMType::getVoidTy(llvmDialect), {LLVMType::getInt1Ty(llvmDialect)}, + /*isVarArg=*/false); + OpBuilder builder(module.getBodyRegion()); + return builder.create(module.getLoc(), "__npcomp_abort_if", + abortIfFuncTy, + LLVM::Linkage::External); +} + +namespace { +class LowerToLLVM : public LowerToLLVMBase { + void runOnOperation() { + auto module = getOperation(); + auto *context = &getContext(); + + LLVM::LLVMFuncOp abortIfFunc = createAbortIfFuncDecl(module); + + LLVMTypeConverter converter(context); + OwningRewritePatternList patterns; + LLVMConversionTarget target(*context); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addLegalOp(); + populateStdToLLVMConversionPatterns(converter, patterns); + patterns.insert(abortIfFunc); + + if (failed(applyFullConversion(module, target, patterns, &converter))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> mlir::NPCOMP::createLowerToLLVMPass() { + return std::make_unique(); +} diff --git a/test/E2E/lower-to-llvm.mlir b/test/E2E/lower-to-llvm.mlir new file mode 100644 index 000000000..02eaebd94 --- /dev/null +++ b/test/E2E/lower-to-llvm.mlir @@ -0,0 +1,15 @@ +// RUN: npcomp-opt -e2e-lower-to-llvm <%s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: llvm.func @identity(%arg0: !llvm.i64, %arg1: !llvm<"i8*">) -> !llvm<"{ i64, i8* }"> +func @identity(%arg0: memref<*xf32>) -> memref<*xf32> { + return %arg0 : memref<*xf32> +} + +// CHECK-LABEL: llvm.func @abort_if( +// CHECK-SAME: %[[PRED:.*]]: !llvm.i1) +func @abort_if(%arg0: i1) { + // CHECK: llvm.call @__npcomp_abort_if(%arg0) : (!llvm.i1) -> () + "tcp.abort_if"(%arg0) : (i1) -> () + return +} +