Make some passes run on FuncOp so they can run in parallel.

pull/116/head
Sean Silva 2020-11-13 15:34:24 -08:00
parent 482791fa4a
commit 32388d938b
8 changed files with 16 additions and 19 deletions

View File

@ -42,7 +42,7 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
// TCFToTCP
//===----------------------------------------------------------------------===//
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "FuncOp"> {
let summary = "Convert TCF to Linalg";
let description = [{
The intention is for this pass to convert mainly to linalg named ops.
@ -57,7 +57,7 @@ def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
// TCFToStd
//===----------------------------------------------------------------------===//
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "FuncOp"> {
let summary = "Convert TCF to Std";
let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()";
}
@ -66,7 +66,7 @@ def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
// TCFToTCP
//===----------------------------------------------------------------------===//
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "FuncOp"> {
let summary = "Convert TCF to TCP";
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
}

View File

@ -14,7 +14,7 @@
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToLinalgPass();
}
} // namespace mlir

View File

@ -14,7 +14,7 @@
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToStdPass();
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToStdPass();
}
} // namespace mlir

View File

@ -14,7 +14,7 @@
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToTCPPass();
}
} // namespace mlir

View File

@ -84,8 +84,7 @@ public:
}
void runOnOperation() override {
ModuleOp module = getOperation();
(void)applyPatternsAndFoldGreedily(module, getPatterns());
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
}
FrozenRewritePatternList getPatterns() {
@ -97,7 +96,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToLinalgPass() {
return std::make_unique<ConvertTCFToLinalg>();
}

View File

@ -139,8 +139,7 @@ public:
}
void runOnOperation() override {
ModuleOp module = getOperation();
(void)applyPatternsAndFoldGreedily(module, getPatterns());
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
}
FrozenRewritePatternList getPatterns() {
@ -156,7 +155,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToStdPass() {
return std::make_unique<ConvertTCFToStd>();
}

View File

@ -29,8 +29,7 @@ public:
}
void runOnOperation() override {
ModuleOp module = getOperation();
(void)applyPatternsAndFoldGreedily(module, getPatterns());
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
}
FrozenRewritePatternList getPatterns() {
@ -43,7 +42,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToTCPPass() {
return std::make_unique<ConvertTCFToTCP>();
}

View File

@ -243,7 +243,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// Run some upstream bufferization passes to finish bufferization.
pm.addNestedPass<FuncOp>(createStdBufferizePass());
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
pm.addPass(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addPass(createFuncBufferizePass());
// TODO: Do buffer deallocation. We should be able to just drop in the
@ -306,9 +306,9 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
//
// TCP does not. So we need to reify the broadcasting and error checking.
// These all run at the module level.
pm.addPass(createConvertTCFToStdPass());
pm.addPass(createConvertTCFToLinalgPass());
pm.addPass(createConvertTCFToTCPPass());
pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
pm.addNestedPass<FuncOp>(createConvertTCFToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTCFToTCPPass());
createRefBackendLoweringPipeline(pm, options);
}