mirror of https://github.com/llvm/torch-mlir
Add `RecomposeComplexOps` declaration + fix typos in pass name (#1950)
The `RecomposeComplexOps` pass currently does not have a TableGen declaration and it is using the base class of `DecomposeComplexOps`, which causes `--mlir-print-ir-after-all` to create wrong pass labels. This commit fixes that as well as some minor typos in the name of the pass.pull/1986/head snapshot-20230329.792
parent
d803ab4eeb
commit
0103c55e55
|
@ -106,7 +106,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
||||||
|
|
|
@ -244,6 +244,29 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
|
||||||
|
let summary = "Recompose torch operations that have been decomposed by TorchScript";
|
||||||
|
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
|
||||||
|
let description = [{
|
||||||
|
There are certain ops that TorchScript will split into multiple ops that
|
||||||
|
prevent optimizations in Torch-MLIR from taking place. In this pass such
|
||||||
|
sequences of ops are identified and combined into a higher level op,
|
||||||
|
preserving the original behavior, while allowing new optimizations to happen.
|
||||||
|
|
||||||
|
An example is the handling of the indexing operation in PyTorch. The following
|
||||||
|
|
||||||
|
```
|
||||||
|
input_tensor[1:2, :] = 7
|
||||||
|
```
|
||||||
|
|
||||||
|
will get split into a series of `slice` ops to get the sub-tensor, then an
|
||||||
|
in-place copy to overwrite the sub-tensor with the value 7. This type of
|
||||||
|
pattern prevents the `MaximizeValueSemantics` pass from succeeding. So,
|
||||||
|
using `RecomposeComplexOps`, the series of slices + copy is identified
|
||||||
|
and turned into a single `index_put` operation.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
||||||
let summary = "Reify shape calculations.";
|
let summary = "Reify shape calculations.";
|
||||||
let constructor = [{
|
let constructor = [{
|
||||||
|
|
|
@ -107,7 +107,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
// Clean up again to avoid needing to to back around the fixed-point
|
// Clean up again to avoid needing to to back around the fixed-point
|
||||||
// iteration.
|
// iteration.
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
|
||||||
// Reduce variants of ops to a smaller set of primitives.
|
// Reduce variants of ops to a smaller set of primitives.
|
||||||
pm.addNestedPass<func::FuncOp>(
|
pm.addNestedPass<func::FuncOp>(
|
||||||
createReduceOpVariantsPass(options.extraLibrary));
|
createReduceOpVariantsPass(options.extraLibrary));
|
||||||
|
|
|
@ -74,10 +74,9 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class RecomposeComplexOps
|
class RecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
|
: public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
|
||||||
public:
|
public:
|
||||||
RecomposeComplexOps() = default;
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
@ -98,6 +97,6 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
mlir::torch::Torch::createRecomposeComplexOps() {
|
mlir::torch::Torch::createRecomposeComplexOpsPass() {
|
||||||
return std::make_unique<RecomposeComplexOps>();
|
return std::make_unique<RecomposeComplexOpsPass>();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue