Make some passes run on FuncOp so they can run in parallel.

pull/116/head
Sean Silva 2020-11-13 15:34:24 -08:00
parent 482791fa4a
commit 32388d938b
8 changed files with 16 additions and 19 deletions

View File

@ -42,7 +42,7 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
// TCFToTCP // TCFToTCP
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> { def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "FuncOp"> {
let summary = "Convert TCF to Linalg"; let summary = "Convert TCF to Linalg";
let description = [{ let description = [{
The intention is for this pass to convert mainly to linalg named ops. The intention is for this pass to convert mainly to linalg named ops.
@ -57,7 +57,7 @@ def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
// TCFToStd // TCFToStd
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> { def ConvertTCFToStd : Pass<"convert-tcf-to-std", "FuncOp"> {
let summary = "Convert TCF to Std"; let summary = "Convert TCF to Std";
let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()"; let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()";
} }
@ -66,7 +66,7 @@ def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
// TCFToTCP // TCFToTCP
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> { def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "FuncOp"> {
let summary = "Convert TCF to TCP"; let summary = "Convert TCF to TCP";
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()"; let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
} }

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass(); std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToLinalgPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToStdPass(); std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToStdPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass(); std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToTCPPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -84,8 +84,7 @@ public:
} }
void runOnOperation() override { void runOnOperation() override {
ModuleOp module = getOperation(); (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
(void)applyPatternsAndFoldGreedily(module, getPatterns());
} }
FrozenRewritePatternList getPatterns() { FrozenRewritePatternList getPatterns() {
@ -97,7 +96,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToLinalgPass() { mlir::NPCOMP::createConvertTCFToLinalgPass() {
return std::make_unique<ConvertTCFToLinalg>(); return std::make_unique<ConvertTCFToLinalg>();
} }

View File

@ -139,8 +139,7 @@ public:
} }
void runOnOperation() override { void runOnOperation() override {
ModuleOp module = getOperation(); (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
(void)applyPatternsAndFoldGreedily(module, getPatterns());
} }
FrozenRewritePatternList getPatterns() { FrozenRewritePatternList getPatterns() {
@ -156,7 +155,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToStdPass() { mlir::NPCOMP::createConvertTCFToStdPass() {
return std::make_unique<ConvertTCFToStd>(); return std::make_unique<ConvertTCFToStd>();
} }

View File

@ -29,8 +29,7 @@ public:
} }
void runOnOperation() override { void runOnOperation() override {
ModuleOp module = getOperation(); (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
(void)applyPatternsAndFoldGreedily(module, getPatterns());
} }
FrozenRewritePatternList getPatterns() { FrozenRewritePatternList getPatterns() {
@ -43,7 +42,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToTCPPass() { mlir::NPCOMP::createConvertTCFToTCPPass() {
return std::make_unique<ConvertTCFToTCP>(); return std::make_unique<ConvertTCFToTCP>();
} }

View File

@ -243,7 +243,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// Run some upstream bufferization passes to finish bufferization. // Run some upstream bufferization passes to finish bufferization.
pm.addNestedPass<FuncOp>(createStdBufferizePass()); pm.addNestedPass<FuncOp>(createStdBufferizePass());
pm.addNestedPass<FuncOp>(createSCFBufferizePass()); pm.addNestedPass<FuncOp>(createSCFBufferizePass());
pm.addPass(createLinalgBufferizePass()); pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addPass(createFuncBufferizePass()); pm.addPass(createFuncBufferizePass());
// TODO: Do buffer deallocation. We should be able to just drop in the // TODO: Do buffer deallocation. We should be able to just drop in the
@ -306,9 +306,9 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
// //
// TCP does not. So we need to reify the broadcasting and error checking. // TCP does not. So we need to reify the broadcasting and error checking.
// These all run at the module level. // These all run at the module level.
pm.addPass(createConvertTCFToStdPass()); pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
pm.addPass(createConvertTCFToLinalgPass()); pm.addNestedPass<FuncOp>(createConvertTCFToLinalgPass());
pm.addPass(createConvertTCFToTCPPass()); pm.addNestedPass<FuncOp>(createConvertTCFToTCPPass());
createRefBackendLoweringPipeline(pm, options); createRefBackendLoweringPipeline(pm, options);
} }