LowerRankedShapes: support shape.const_shape op.

Also, the previous code had a special case for deleting this op when it
had no uses. That special case is subsumed by this change, since
shape.const_shape is now properly lowered.
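
For illustration, the new LowerConstShapeOp pattern rewrites a constant
shape into one index constant per extent plus a shape.from_extents op.
A rough sketch in the shape dialect syntax of this commit's era (the
extent value mirrors the updated test; this is not actual pass output):

    // Input:
    %0 = shape.const_shape [7]
    // After LowerConstShapeOp (roughly):
    %c7 = constant 7 : index
    %0 = shape.from_extents %c7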

With this change, the included test case with multiple serially
dependent ops works!
The specific issue was the scalar argument to that function: we needed
to compute the broadcast of a scalar shape (which is a
shape.const_shape) with another shape.
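
A hand-written sketch of that situation (not the pass's literal output;
%d0 and %d1 stand in for the other operand's extents):

    // Shape of the scalar argument: rank 0, no extents.
    %lhs = shape.const_shape []
    // Shape of the other operand, built from its extents.
    %rhs = shape.from_extents %d0, %d1
    // LowerShapeBroadcastOp needs both operands lowered to
    // shape.from_extents so it can read their extents; previously the
    // scalar side stayed a shape.const_shape and blocked the lowering.
    // Now the const_shape is rewritten to an extent-less
    // shape.from_extents first.
    %bcast = "shape.broadcast"(%lhs, %rhs)
        : (!shape.shape, !shape.shape) -> !shape.shape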
Sean Silva 2020-07-08 20:10:01 -07:00
parent b4f0cea8fa
commit f18014f60c
3 changed files with 50 additions and 7 deletions


@@ -18,8 +18,24 @@
using namespace mlir;
using namespace mlir::NPCOMP;
// This has to be a "conversion pattern" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
namespace {
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto extents = llvm::to_vector<6>(llvm::map_range(
op.shape().getValues<int64_t>(), [&](int64_t extent) -> Value {
return rewriter.create<ConstantIndexOp>(op.getLoc(), extent);
}));
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
op, rewriter.getType<shape::ShapeType>(), extents);
return success();
}
};
} // namespace
namespace {
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
@@ -137,6 +153,8 @@ public:
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
// These patterns have to be "conversion patterns" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
@@ -156,6 +174,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<LowerConstShapeOp>(context);
patterns.insert<LowerShapeBroadcastOp>(context);
patterns.insert<LowerShapeGetExtentOp>(context);
patterns.insert<EraseShapeObserveErrorOp>(context);
@@ -175,7 +194,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
// deleted during conversion because they become unused only after
// subsequent patterns bypass them.
auto walkResult = func.walk([](Operation *op) {
if (!(isa<shape::FromExtentsOp>(op) || isa<shape::ConstShapeOp>(op)))
if (!isa<shape::FromExtentsOp>(op))
return WalkResult::advance();
if (!op->use_empty()) {
op->emitError("could not be eliminated");


@@ -20,10 +20,6 @@ func @erase_stray_shape_ops(%arg0: index) {
// CHECK-NOT: shape.from_extents
%0 = shape.from_extents %arg0
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
// CHECK-NOT: tcp.shape_observe_error
// CHECK-NOT: shape.const_shape
%1 = shape.const_shape []
"tcp.shape_observe_error"(%1) : (!shape.shape) -> none
return
}
@@ -34,3 +30,16 @@ func @cannot_erase_stray_shape_ops() -> !shape.shape {
%0 = shape.from_extents
return %0 : !shape.shape
}
// -----
// CHECK-LABEL: func @const_shape
func @const_shape() -> index {
// CHECK-NOT: shape.const_shape
%0 = shape.const_shape []
%1 = shape.const_shape [7]
%2 = tcp.get_extent %1, 0
// CHECK: %[[C7:.*]] = constant 7 : index
// CHECK: return %[[C7]]
return %2 : index
}


@@ -0,0 +1,15 @@
// RUN: npcomp-run-mlir -input %s \
// RUN: -invoke multiple_ops \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -arg-value="dense<[1.0]> : tensor<1xf32>" \
// RUN: -arg-value="dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<[
// CHECK-SAME: [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]]> : tensor<2x2xf32>
func @multiple_ops(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tcf.add"(%arg1, %arg2) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}