Add `RecomposeComplexOps` declaration + fix typos in pass name (#1950)

The `RecomposeComplexOps` pass currently does not have a TableGen
declaration and it is using the base class of `DecomposeComplexOps`,
which causes `--mlir-print-ir-after-all` to create wrong pass
labels. This commit fixes that as well as some minor typos in the name
of the pass.
pull/1986/head snapshot-20230329.792
Ramiro Leal-Cavazos 2023-03-28 11:07:47 -07:00 committed by GitHub
parent d803ab4eeb
commit 0103c55e55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 7 deletions

View File

@ -106,7 +106,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps); createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps(); std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary); createReifyShapeCalculationsPass(StringRef extraLibrary);

View File

@ -244,6 +244,29 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
}]; }];
} }
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
let summary = "Recompose torch operations that have been decomposed by TorchScript";
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
let description = [{
There are certain ops that TorchScript will split into multiple ops that
prevent optimizations in Torch-MLIR from taking place. In this pass such
sequences of ops are identified and combined into a higher level op,
preserving the original behavior, while allowing new optimizations to happen.
An example is the handling of the indexing operation in PyTorch. The following
```
input_tensor[1:2, :] = 7
```
will get split into a series of `slice` ops to get the sub-tensor, then an
in-place copy to overwrite the sub-tensor with the value 7. This type of
pattern prevents the `MaximizeValueSemantics` pass from succeeding. So,
using `RecomposeComplexOps`, the series of slices + copy is identified
and turned into a single `index_put` operation.
}];
}
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> { def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Reify shape calculations."; let summary = "Reify shape calculations.";
let constructor = [{ let constructor = [{

View File

@ -107,7 +107,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to to back around the fixed-point // Clean up again to avoid needing to to back around the fixed-point
// iteration. // iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps()); pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
// Reduce variants of ops to a smaller set of primitives. // Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary)); createReduceOpVariantsPass(options.extraLibrary));

View File

@ -74,10 +74,9 @@ public:
} // namespace } // namespace
namespace { namespace {
class RecomposeComplexOps class RecomposeComplexOpsPass
: public DecomposeComplexOpsBase<RecomposeComplexOps> { : public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
public: public:
RecomposeComplexOps() = default;
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
@ -98,6 +97,6 @@ public:
} // namespace } // namespace
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() { mlir::torch::Torch::createRecomposeComplexOpsPass() {
return std::make_unique<RecomposeComplexOps>(); return std::make_unique<RecomposeComplexOpsPass>();
} }