mirror of https://github.com/llvm/torch-mlir
Adjust pass pipeline for changes to `dim` canonicalization.
This results in cleaner IR. In particular, Mlp2LayerModule e2e test has a dim op that is eliminated by this change: https://gist.github.com/silvasean/734f11a291ae6236c955f65cffae285fpull/233/head
parent
1bc889130d
commit
40369c54dc
|
@ -25,4 +25,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
NPCOMPTorchDialect
|
NPCOMPTorchDialect
|
||||||
NPCOMPTorchToLinalg
|
NPCOMPTorchToLinalg
|
||||||
NPCOMPInterfaces
|
NPCOMPInterfaces
|
||||||
|
MLIRMemRefTransforms
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "npcomp/Backend/Common/Passes.h"
|
#include "npcomp/Backend/Common/Passes.h"
|
||||||
|
@ -157,6 +158,11 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
// Clean up any non-canonical code introduced in our linalg lowering.
|
// Clean up any non-canonical code introduced in our linalg lowering.
|
||||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||||
|
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
||||||
|
// dialect for some reason -- we don't have memrefs at this level).
|
||||||
|
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
|
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finish the type conversion from !torch.vtensor to the builtin tensor type.
|
// Finish the type conversion from !torch.vtensor to the builtin tensor type.
|
||||||
|
|
Loading…
Reference in New Issue