diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td index b43d20760..563e42f42 100644 --- a/include/npcomp/Conversion/Passes.td +++ b/include/npcomp/Conversion/Passes.td @@ -42,7 +42,7 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> { // TCFToTCP //===----------------------------------------------------------------------===// -def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> { +def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "FuncOp"> { let summary = "Convert TCF to Linalg"; let description = [{ The intention is for this pass to convert mainly to linalg named ops. @@ -57,7 +57,7 @@ def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> { // TCFToStd //===----------------------------------------------------------------------===// -def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> { +def ConvertTCFToStd : Pass<"convert-tcf-to-std", "FuncOp"> { let summary = "Convert TCF to Std"; let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()"; } @@ -66,7 +66,7 @@ def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> { // TCFToTCP //===----------------------------------------------------------------------===// -def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> { +def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "FuncOp"> { let summary = "Convert TCF to TCP"; let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()"; } diff --git a/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h index 1bbfc588c..8c4eca694 100644 --- a/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h +++ b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h @@ -14,7 +14,7 @@ namespace mlir { namespace NPCOMP { -std::unique_ptr> createConvertTCFToLinalgPass(); +std::unique_ptr> createConvertTCFToLinalgPass(); } } // namespace mlir diff --git a/include/npcomp/Conversion/TCFToStd/TCFToStd.h b/include/npcomp/Conversion/TCFToStd/TCFToStd.h index 96588d0b8..1a0095522 100644 --- a/include/npcomp/Conversion/TCFToStd/TCFToStd.h +++ b/include/npcomp/Conversion/TCFToStd/TCFToStd.h @@ -14,7 +14,7 @@ namespace mlir { namespace NPCOMP { -std::unique_ptr> createConvertTCFToStdPass(); +std::unique_ptr> createConvertTCFToStdPass(); } } // namespace mlir diff --git a/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h b/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h index 1947b74ba..d93049d1c 100644 --- a/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h +++ b/include/npcomp/Conversion/TCFToTCP/TCFToTCP.h @@ -14,7 +14,7 @@ namespace mlir { namespace NPCOMP { -std::unique_ptr> createConvertTCFToTCPPass(); +std::unique_ptr> createConvertTCFToTCPPass(); } } // namespace mlir diff --git a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp index 35c9f54fa..2494744a3 100644 --- a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp +++ b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp @@ -84,8 +84,7 @@ public: } void runOnOperation() override { - ModuleOp module = getOperation(); - (void)applyPatternsAndFoldGreedily(module, getPatterns()); + (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns()); } FrozenRewritePatternList getPatterns() { @@ -97,7 +96,7 @@ public: }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::NPCOMP::createConvertTCFToLinalgPass() { return std::make_unique(); } diff --git a/lib/Conversion/TCFToStd/TCFToStd.cpp b/lib/Conversion/TCFToStd/TCFToStd.cpp index 9ddb89743..506f41512 100644 --- a/lib/Conversion/TCFToStd/TCFToStd.cpp +++ b/lib/Conversion/TCFToStd/TCFToStd.cpp @@ -139,8 +139,7 @@ public: } void runOnOperation() override { - ModuleOp module = getOperation(); - (void)applyPatternsAndFoldGreedily(module, getPatterns()); + (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns()); } FrozenRewritePatternList getPatterns() { @@ -156,7 +155,7 @@ public: }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::NPCOMP::createConvertTCFToStdPass() { return std::make_unique(); } diff --git a/lib/Conversion/TCFToTCP/TCFToTCP.cpp b/lib/Conversion/TCFToTCP/TCFToTCP.cpp index dba2dd0c3..334ffc7dd 100644 --- a/lib/Conversion/TCFToTCP/TCFToTCP.cpp +++ b/lib/Conversion/TCFToTCP/TCFToTCP.cpp @@ -29,8 +29,7 @@ public: } void runOnOperation() override { - ModuleOp module = getOperation(); - (void)applyPatternsAndFoldGreedily(module, getPatterns()); + (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns()); } FrozenRewritePatternList getPatterns() { @@ -43,7 +42,7 @@ public: }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::NPCOMP::createConvertTCFToTCPPass() { return std::make_unique(); } diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 67498e1f7..f2e10dd2f 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -243,7 +243,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // Run some upstream bufferization passes to finish bufferization. pm.addNestedPass(createStdBufferizePass()); pm.addNestedPass(createSCFBufferizePass()); - pm.addPass(createLinalgBufferizePass()); + pm.addNestedPass(createLinalgBufferizePass()); pm.addPass(createFuncBufferizePass()); // TODO: Do buffer deallocation. We should be able to just drop in the @@ -306,9 +306,9 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline( // // TCP does not. So we need to reify the broadcasting and error checking. // These all run at the module level. - pm.addPass(createConvertTCFToStdPass()); - pm.addPass(createConvertTCFToLinalgPass()); - pm.addPass(createConvertTCFToTCPPass()); + pm.addNestedPass(createConvertTCFToStdPass()); + pm.addNestedPass(createConvertTCFToLinalgPass()); + pm.addNestedPass(createConvertTCFToTCPPass()); createRefBackendLoweringPipeline(pm, options); }