Add ability to run without optimizations.

The default is to only do the bare minimum needed for correctness, since
that stresses the layering of the system maximally.
pull/1/head
Sean Silva 2020-06-01 19:30:13 -07:00
parent e8b1a07ef4
commit 7b9f0c3364
6 changed files with 27 additions and 12 deletions

View File

@ -42,8 +42,17 @@ std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
struct E2ELoweringPipelineOptions
: public PassPipelineOptions<E2ELoweringPipelineOptions> {
// If this option is true, then perform optimizations.
// If this option is false, only do the bare minimum for correctness.
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
llvm::cl::init(false)};
};
// The main pipeline that encapsulates the full E2E lowering.
void createE2ELoweringPipeline(OpPassManager &pm);
void createE2ELoweringPipeline(OpPassManager &pm,
const E2ELoweringPipelineOptions &options);
} // namespace NPCOMP
} // namespace mlir

View File

@ -35,9 +35,9 @@ inline void registerAllPasses() {
using mlir::Pass; // The .inc files reference this unqualified.
#define GEN_PASS_REGISTRATION
#include "npcomp/E2E/Passes.h.inc"
mlir::PassPipelineRegistration<>("e2e-lowering-pipeline",
"E2E lowering pipeline.",
mlir::NPCOMP::createE2ELoweringPipeline);
mlir::PassPipelineRegistration<E2ELoweringPipelineOptions>(
"e2e-lowering-pipeline", "E2E lowering pipeline.",
mlir::NPCOMP::createE2ELoweringPipeline);
mlir::PassPipelineRegistration<>(
"lower-to-hybrid-tensor-memref-pipeline",
"Pipeline lowering to hybrid tensor/memref.",

View File

@ -293,7 +293,8 @@ mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
// createE2ELoweringPipeline
//===----------------------------------------------------------------------===//
void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
void mlir::NPCOMP::createE2ELoweringPipeline(
OpPassManager &pm, const E2ELoweringPipelineOptions &options) {
// Input IR is TCF ops.
// Convert to TCP.
@ -376,8 +377,10 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
// At this point, we have loose shape calculations floating around, so
// it's a good time to do some general cleanups.
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
// --------------------------------------------------------------------------
// Preparation for converting to an LLVM module.
@ -425,10 +428,10 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
pm.addPass(createLowerRankedShapesPass());
// Run a some final cleanups.
// These are optimizations and not needed for correctness.
// TODO: Add tests that they aren't needed for correctness.
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
// --------------------------------------------------------------------------
// Final conversion to an LLVM module.

View File

@ -1,4 +1,5 @@
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline{optimize} | FileCheck %s --dump-input=fail
// This is the simplest case, which is easy to stare at for debugging
// purposes.

View File

@ -1,4 +1,5 @@
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline{optimize} | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @rank1
func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {

View File

@ -145,7 +145,8 @@ Error compileAndRun(std::string mlirFile, std::string invokeFunction,
PassManager pm(&context, /*verifyPasses=*/true);
applyPassManagerCLOptions(pm);
NPCOMP::createE2ELoweringPipeline(pm);
NPCOMP::E2ELoweringPipelineOptions options;
NPCOMP::createE2ELoweringPipeline(pm, options);
llvm::errs() << "RUNNING PIPELINE: ";
pm.printAsTextualPipeline(llvm::errs());
llvm::errs() << "\n";