mirror of https://github.com/llvm/torch-mlir
Add `RecomposeComplexOps` declaration + fix typos in pass name (#1950)
The `RecomposeComplexOps` pass currently does not have a TableGen declaration and it is using the base class of `DecomposeComplexOps`, which causes `--mlir-print-ir-after-all` to create wrong pass labels. This commit fixes that as well as some minor typos in the name of the pass.pull/1986/head snapshot-20230329.792
parent
d803ab4eeb
commit
0103c55e55
|
@ -106,7 +106,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
|||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
||||
|
|
|
@ -244,6 +244,29 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
|
||||
let summary = "Recompose torch operations that have been decomposed by TorchScript";
|
||||
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
|
||||
let description = [{
|
||||
There are certain ops that TorchScript will split into multiple ops that
|
||||
prevent optimizations in Torch-MLIR from taking place. In this pass such
|
||||
sequences of ops are identified and combined into a higher level op,
|
||||
preserving the original behavior, while allowing new optimizations to happen.
|
||||
|
||||
An example is the handling of the indexing operation in PyTorch. The following
|
||||
|
||||
```
|
||||
input_tensor[1:2, :] = 7
|
||||
```
|
||||
|
||||
will get split into a series of `slice` ops to get the sub-tensor, then an
|
||||
in-place copy to overwrite the sub-tensor with the value 7. This type of
|
||||
pattern prevents the `MaximizeValueSemantics` pass from succeeding. So,
|
||||
using `RecomposeComplexOps`, the series of slices + copy is identified
|
||||
and turned into a single `index_put` operation.
|
||||
}];
|
||||
}
|
||||
|
||||
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
||||
let summary = "Reify shape calculations.";
|
||||
let constructor = [{
|
||||
|
|
|
@ -107,7 +107,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
// Clean up again to avoid needing to to back around the fixed-point
|
||||
// iteration.
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
createReduceOpVariantsPass(options.extraLibrary));
|
||||
|
|
|
@ -74,10 +74,9 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeComplexOps
|
||||
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
|
||||
class RecomposeComplexOpsPass
|
||||
: public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
|
||||
public:
|
||||
RecomposeComplexOps() = default;
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
|
@ -98,6 +97,6 @@ public:
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createRecomposeComplexOps() {
|
||||
return std::make_unique<RecomposeComplexOps>();
|
||||
mlir::torch::Torch::createRecomposeComplexOpsPass() {
|
||||
return std::make_unique<RecomposeComplexOpsPass>();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue