mirror of https://github.com/llvm/torch-mlir
Make some passes run on FuncOp so they can run in parallel.
parent
482791fa4a
commit
32388d938b
|
@ -42,7 +42,7 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
|
|||
// TCFToTCP
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
|
||||
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "FuncOp"> {
|
||||
let summary = "Convert TCF to Linalg";
|
||||
let description = [{
|
||||
The intention is for this pass to convert mainly to linalg named ops.
|
||||
|
@ -57,7 +57,7 @@ def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
|
|||
// TCFToStd
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
|
||||
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "FuncOp"> {
|
||||
let summary = "Convert TCF to Std";
|
||||
let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()";
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
|
|||
// TCFToTCP
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "FuncOp"> {
|
||||
let summary = "Convert TCF to TCP";
|
||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToLinalgPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToStdPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToStdPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToTCPPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -84,8 +84,7 @@ public:
|
|||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
|
@ -97,7 +96,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertTCFToLinalgPass() {
|
||||
return std::make_unique<ConvertTCFToLinalg>();
|
||||
}
|
||||
|
|
|
@ -139,8 +139,7 @@ public:
|
|||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
|
@ -156,7 +155,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertTCFToStdPass() {
|
||||
return std::make_unique<ConvertTCFToStd>();
|
||||
}
|
||||
|
|
|
@ -29,8 +29,7 @@ public:
|
|||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternList getPatterns() {
|
||||
|
@ -43,7 +42,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertTCFToTCPPass() {
|
||||
return std::make_unique<ConvertTCFToTCP>();
|
||||
}
|
||||
|
|
|
@ -243,7 +243,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// Run some upstream bufferization passes to finish bufferization.
|
||||
pm.addNestedPass<FuncOp>(createStdBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
|
||||
pm.addPass(createLinalgBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||
pm.addPass(createFuncBufferizePass());
|
||||
|
||||
// TODO: Do buffer deallocation. We should be able to just drop in the
|
||||
|
@ -306,9 +306,9 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
|
|||
//
|
||||
// TCP does not. So we need to reify the broadcasting and error checking.
|
||||
// These all run at the module level.
|
||||
pm.addPass(createConvertTCFToStdPass());
|
||||
pm.addPass(createConvertTCFToLinalgPass());
|
||||
pm.addPass(createConvertTCFToTCPPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTCFToLinalgPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTCFToTCPPass());
|
||||
|
||||
createRefBackendLoweringPipeline(pm, options);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue