mirror of https://github.com/llvm/torch-mlir
Make LowerRankedShapes clean up shape.from_extents ops.
We were previously relying on a later canonicalization pass to clean them up, but it is a cleaner invariant if the pass gets rid of them itself.pull/1/head
parent
67b129af7a
commit
e7b5a2b8a3
|
@ -424,12 +424,11 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
|||
// pass that checks no !shape.shape types left.
|
||||
pm.addPass(createLowerRankedShapesPass());
|
||||
|
||||
// Run a final canonicalization pass to delete dead
|
||||
// `shape.from_extents` ops.
|
||||
// This is needed for correctness, since we can't currently lower that op
|
||||
// to LLVM, since we don't have a runtime representation of `!shape.shape`.
|
||||
// TODO: Change LowerRankedShapes to delete these ops itself.
|
||||
// Run a some final cleanups.
|
||||
// These are optimizations and not needed for correctness.
|
||||
// TODO: Add tests that they aren't needed for correctness.
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Final conversion to an LLVM module.
|
||||
|
|
|
@ -165,6 +165,20 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
|||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// Erase all shape::FromExtentsOp's from the program. They can't be
|
||||
// deleted during conversion because they become unused only after
|
||||
// subsequent patterns bypass them.
|
||||
auto walkResult = func.walk([](shape::FromExtentsOp op) {
|
||||
if (!op.use_empty()) {
|
||||
op.emitError("could not be eliminated");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
op.erase();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (walkResult.wasInterrupted())
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
// RUN: npcomp-opt -lower-ranked-shapes <%s | FileCheck %s --dump-input=fail
|
||||
// RUN: npcomp-opt -lower-ranked-shapes <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail
|
||||
|
||||
|
||||
// CHECK-LABEL: func @broadcast_rank2_rank1
|
||||
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
|
||||
// CHECK-NOT: shape.broadcast
|
||||
// CHECK-NOT: tcp.get_extent
|
||||
// CHECK-NOT: shape.from_extents
|
||||
%0 = shape.from_extents %arg0, %arg1
|
||||
%1 = shape.from_extents %arg2
|
||||
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
|
@ -15,8 +16,17 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
|
|||
|
||||
// CHECK-LABEL: func @erase_shape_observe_error
|
||||
func @erase_shape_observe_error(%arg0: index) {
|
||||
// CHECK-NOT tcp.shape_observe_error
|
||||
// CHECK-NOT: tcp.shape_observe_error
|
||||
// CHECK-NOT: shape.from_extents
|
||||
%0 = shape.from_extents %arg0
|
||||
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cannot_erase_shape_from_extents() -> !shape.shape {
|
||||
// expected-error @+1 {{could not be eliminated}}
|
||||
%0 = shape.from_extents
|
||||
return %0 : !shape.shape
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue