Adjust pass pipeline for changes to `dim` canonicalization.

This results in cleaner IR. In particular, Mlp2LayerModule e2e test has
a dim op that is eliminated by this change:
https://gist.github.com/silvasean/734f11a291ae6236c955f65cffae285f
pull/233/head
Sean Silva 2021-06-17 14:57:31 -07:00
parent 1bc889130d
commit 40369c54dc
2 changed files with 7 additions and 0 deletions

View File

@ -25,4 +25,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
NPCOMPTorchDialect NPCOMPTorchDialect
NPCOMPTorchToLinalg NPCOMPTorchToLinalg
NPCOMPInterfaces NPCOMPInterfaces
MLIRMemRefTransforms
) )

View File

@ -8,6 +8,7 @@
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h" #include "npcomp/Backend/Common/Passes.h"
@ -157,6 +158,11 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
if (options.optimize) { if (options.optimize) {
// Clean up any non-canonical code introduced in our linalg lowering. // Clean up any non-canonical code introduced in our linalg lowering.
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass());
} }
// Finish the type conversion from !torch.vtensor to the builtin tensor type. // Finish the type conversion from !torch.vtensor to the builtin tensor type.