mirror of https://github.com/llvm/torch-mlir
Lower to LLVM dialect.
With this commit, we finish conversion to LLVM dialect, and should be ready for subsequent commits to convert to an LLVM module and let LLVM codegen to native machine code. This required a custom "lower to LLVM" pass to support lowering tcp.abort_if to a runtime call. In the future, this pass will grow to do type conversions for our own runtime types as we add those.pull/1/head
parent
be1971c4fc
commit
1d3dbd9d5c
|
@ -38,6 +38,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerToMemRefABIPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
||||||
|
|
||||||
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
|
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
// The main pipeline that encapsulates the full E2E lowering.
|
// The main pipeline that encapsulates the full E2E lowering.
|
||||||
|
|
|
@ -53,4 +53,9 @@ def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()";
|
let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LowerToLLVM : Pass<"e2e-lower-to-llvm", "ModuleOp"> {
|
||||||
|
let summary = "Lower everything to LLVM";
|
||||||
|
let constructor = "mlir::NPCOMP::createLowerToLLVMPass();";
|
||||||
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_E2E_PASSES
|
#endif // NPCOMP_E2E_PASSES
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_mlir_library(NPCOMPE2E
|
||||||
E2E.cpp
|
E2E.cpp
|
||||||
LowerRankedShapes.cpp
|
LowerRankedShapes.cpp
|
||||||
LowerToHybridTensorMemRef.cpp
|
LowerToHybridTensorMemRef.cpp
|
||||||
|
LowerToLLVM.cpp
|
||||||
LowerToMemRefABI.cpp
|
LowerToMemRefABI.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
@ -18,4 +19,6 @@ add_mlir_library(NPCOMPE2E
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLinalgOps
|
MLIRLinalgOps
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
|
MLIRStandardToLLVM
|
||||||
|
MLIRLoopToStandard
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,6 +42,7 @@
|
||||||
#include "npcomp/E2E/E2E.h"
|
#include "npcomp/E2E/E2E.h"
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
@ -379,7 +380,7 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
||||||
pm.addPass(createCSEPass());
|
pm.addPass(createCSEPass());
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Lowering down to LLVM
|
// Preparation for converting to an LLVM module.
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Now, we begin the process of lowering to LLVM's level of abstraction
|
// Now, we begin the process of lowering to LLVM's level of abstraction
|
||||||
// (after which LLVM will take over lowering to machine code).
|
// (after which LLVM will take over lowering to machine code).
|
||||||
|
@ -423,6 +424,23 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
||||||
// pass that checks no !shape.shape types left.
|
// pass that checks no !shape.shape types left.
|
||||||
pm.addPass(createLowerRankedShapesPass());
|
pm.addPass(createLowerRankedShapesPass());
|
||||||
|
|
||||||
// TODO:
|
|
||||||
// Convert all of it to LLVM?
|
// Run a final canonicalization pass to delete dead
|
||||||
|
// `tcp.shape_from_extents` ops.
|
||||||
|
// This is needed for correctness, since we can't currently lower that op
|
||||||
|
// to LLVM, since we don't have a runtime representation of `!shape.shape`.
|
||||||
|
// TODO: Change LowerRankedShapes to delete these ops itself.
|
||||||
|
pm.addPass(createCanonicalizerPass());
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Final conversion to an LLVM module.
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Convert scf to std control flow in preparation for going to LLVM.
|
||||||
|
pm.addPass(createLowerToCFGPass());
|
||||||
|
|
||||||
|
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
||||||
|
// which reuses the upstream patterns and gives us a place to add our own
|
||||||
|
// patterns for any custom ops and types we wish to lower.
|
||||||
|
pm.addPass(createLowerToLLVMPass());
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/E2E/E2E.h"
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using mlir::LLVM::LLVMType;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerAbortIf : public OpConversionPattern<tcp::AbortIfOp> {
|
||||||
|
public:
|
||||||
|
LowerAbortIf(LLVM::LLVMFuncOp abortIfFunc)
|
||||||
|
: OpConversionPattern(abortIfFunc.getContext()),
|
||||||
|
abortIfFunc(abortIfFunc) {}
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::AbortIfOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
tcp::AbortIfOp::OperandAdaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, abortIfFunc, adaptor.pred());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
LLVM::LLVMFuncOp abortIfFunc;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Create the LLVM function declaration for our runtime function
|
||||||
|
// that backs the tcp.abort_if op.
|
||||||
|
LLVM::LLVMFuncOp createAbortIfFuncDecl(ModuleOp module) {
|
||||||
|
auto *llvmDialect =
|
||||||
|
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
auto abortIfFuncTy = LLVMType::getFunctionTy(
|
||||||
|
LLVMType::getVoidTy(llvmDialect), {LLVMType::getInt1Ty(llvmDialect)},
|
||||||
|
/*isVarArg=*/false);
|
||||||
|
OpBuilder builder(module.getBodyRegion());
|
||||||
|
return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "__npcomp_abort_if",
|
||||||
|
abortIfFuncTy,
|
||||||
|
LLVM::Linkage::External);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto module = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
LLVM::LLVMFuncOp abortIfFunc = createAbortIfFuncDecl(module);
|
||||||
|
|
||||||
|
LLVMTypeConverter converter(context);
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
LLVMConversionTarget target(*context);
|
||||||
|
target.addDynamicallyLegalOp<FuncOp>(
|
||||||
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||||
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
|
patterns.insert<LowerAbortIf>(abortIfFunc);
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(module, target, patterns, &converter))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> mlir::NPCOMP::createLowerToLLVMPass() {
|
||||||
|
return std::make_unique<LowerToLLVM>();
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
// RUN: npcomp-opt -e2e-lower-to-llvm <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @identity(%arg0: !llvm.i64, %arg1: !llvm<"i8*">) -> !llvm<"{ i64, i8* }">
|
||||||
|
func @identity(%arg0: memref<*xf32>) -> memref<*xf32> {
|
||||||
|
return %arg0 : memref<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @abort_if(
|
||||||
|
// CHECK-SAME: %[[PRED:.*]]: !llvm.i1)
|
||||||
|
func @abort_if(%arg0: i1) {
|
||||||
|
// CHECK: llvm.call @__npcomp_abort_if(%arg0) : (!llvm.i1) -> ()
|
||||||
|
"tcp.abort_if"(%arg0) : (i1) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue