mirror of https://github.com/llvm/torch-mlir
Make some passes run on FuncOp so they can run in parallel.
parent
482791fa4a
commit
32388d938b
|
@ -42,7 +42,7 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
|
||||||
// TCFToTCP
|
// TCFToTCP
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
|
def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "FuncOp"> {
|
||||||
let summary = "Convert TCF to Linalg";
|
let summary = "Convert TCF to Linalg";
|
||||||
let description = [{
|
let description = [{
|
||||||
The intention is for this pass to convert mainly to linalg named ops.
|
The intention is for this pass to convert mainly to linalg named ops.
|
||||||
|
@ -57,7 +57,7 @@ def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
|
||||||
// TCFToStd
|
// TCFToStd
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
|
def ConvertTCFToStd : Pass<"convert-tcf-to-std", "FuncOp"> {
|
||||||
let summary = "Convert TCF to Std";
|
let summary = "Convert TCF to Std";
|
||||||
let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()";
|
let constructor = "mlir::NPCOMP::createConvertTCFToStdPass()";
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ def ConvertTCFToStd : Pass<"convert-tcf-to-std", "ModuleOp"> {
|
||||||
// TCFToTCP
|
// TCFToTCP
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "FuncOp"> {
|
||||||
let summary = "Convert TCF to TCP";
|
let summary = "Convert TCF to TCP";
|
||||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToLinalgPass();
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToStdPass();
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToStdPass();
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToTCPPass();
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTCFToTCPPass();
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -84,8 +84,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp module = getOperation();
|
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FrozenRewritePatternList getPatterns() {
|
FrozenRewritePatternList getPatterns() {
|
||||||
|
@ -97,7 +96,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::createConvertTCFToLinalgPass() {
|
mlir::NPCOMP::createConvertTCFToLinalgPass() {
|
||||||
return std::make_unique<ConvertTCFToLinalg>();
|
return std::make_unique<ConvertTCFToLinalg>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,8 +139,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp module = getOperation();
|
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FrozenRewritePatternList getPatterns() {
|
FrozenRewritePatternList getPatterns() {
|
||||||
|
@ -156,7 +155,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::createConvertTCFToStdPass() {
|
mlir::NPCOMP::createConvertTCFToStdPass() {
|
||||||
return std::make_unique<ConvertTCFToStd>();
|
return std::make_unique<ConvertTCFToStd>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,8 +29,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp module = getOperation();
|
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||||
(void)applyPatternsAndFoldGreedily(module, getPatterns());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FrozenRewritePatternList getPatterns() {
|
FrozenRewritePatternList getPatterns() {
|
||||||
|
@ -43,7 +42,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::createConvertTCFToTCPPass() {
|
mlir::NPCOMP::createConvertTCFToTCPPass() {
|
||||||
return std::make_unique<ConvertTCFToTCP>();
|
return std::make_unique<ConvertTCFToTCP>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -243,7 +243,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
||||||
// Run some upstream bufferization passes to finish bufferization.
|
// Run some upstream bufferization passes to finish bufferization.
|
||||||
pm.addNestedPass<FuncOp>(createStdBufferizePass());
|
pm.addNestedPass<FuncOp>(createStdBufferizePass());
|
||||||
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
|
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
|
||||||
pm.addPass(createLinalgBufferizePass());
|
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||||
pm.addPass(createFuncBufferizePass());
|
pm.addPass(createFuncBufferizePass());
|
||||||
|
|
||||||
// TODO: Do buffer deallocation. We should be able to just drop in the
|
// TODO: Do buffer deallocation. We should be able to just drop in the
|
||||||
|
@ -306,9 +306,9 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
|
||||||
//
|
//
|
||||||
// TCP does not. So we need to reify the broadcasting and error checking.
|
// TCP does not. So we need to reify the broadcasting and error checking.
|
||||||
// These all run at the module level.
|
// These all run at the module level.
|
||||||
pm.addPass(createConvertTCFToStdPass());
|
pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
|
||||||
pm.addPass(createConvertTCFToLinalgPass());
|
pm.addNestedPass<FuncOp>(createConvertTCFToLinalgPass());
|
||||||
pm.addPass(createConvertTCFToTCPPass());
|
pm.addNestedPass<FuncOp>(createConvertTCFToTCPPass());
|
||||||
|
|
||||||
createRefBackendLoweringPipeline(pm, options);
|
createRefBackendLoweringPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue