LowerRankedShapes: support shape.const_shape op.
Also, the previous code had a special case for deleting this op when it had no uses; that special case is subsumed by this change, since shape.const_shape is now properly lowered. With this change, the included test case with multiple serially dependent ops works. The specific failure involved the scalar argument to that function: we needed to compute a broadcast of a scalar shape (which is a shape.const_shape) with another shape.
parent b4f0cea8fa
commit f18014f60c
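The pattern added below rewrites each shape.const_shape into one index constant per extent, recombined by shape.from_extents, so that later patterns (and the end-of-pass cleanup) can treat it like any other ranked shape. A minimal before/after sketch of the rewrite, with made-up extents rather than ones taken from the tests:

  // Before: the constant shape is a single opaque op.
  %shape = shape.const_shape [2, 7]

  // After LowerConstShapeOp (sketch): one index constant per extent, glued
  // back together by shape.from_extents; the from_extents op itself is
  // erased later, once all of its users have been lowered away.
  %c2 = constant 2 : index
  %c7 = constant 7 : index
  %shape = shape.from_extents %c2, %c7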
@@ -18,8 +18,24 @@
using namespace mlir;
using namespace mlir::NPCOMP;

// This has to be a "conversion pattern" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
namespace {
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto extents = llvm::to_vector<6>(llvm::map_range(
        op.shape().getValues<int64_t>(), [&](int64_t extent) -> Value {
          return rewriter.create<ConstantIndexOp>(op.getLoc(), extent);
        }));
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
        op, rewriter.getType<shape::ShapeType>(), extents);
    return success();
  }
};
} // namespace

namespace {
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
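This also explains why the old special case that simply deleted an unused shape.const_shape is gone: the op is now rewritten unconditionally, and the resulting shape.from_extents falls under the same end-of-pass cleanup as every other lowered shape op (see the walk further down).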
@@ -137,6 +153,8 @@ public:
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
// These patterns have to be "conversion patterns" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
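The "fixed set of SSA extents" point is what the new @const_shape test further down exercises: once a shape is expressed as shape.from_extents of known index values, an extent query can be resolved to the corresponding SSA value and the shape ops themselves become dead. A rough sketch of that chain, using the tcp.get_extent form from the tests:

  %c7 = constant 7 : index
  %shape = shape.from_extents %c7
  // The get-extent lowering forwards the matching SSA extent, so after
  // conversion this value is simply %c7 and the shape ops can be erased.
  %extent = tcp.get_extent %shape, 0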
@@ -156,6 +174,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    patterns.insert<LowerConstShapeOp>(context);
    patterns.insert<LowerShapeBroadcastOp>(context);
    patterns.insert<LowerShapeGetExtentOp>(context);
    patterns.insert<EraseShapeObserveErrorOp>(context);
@@ -175,7 +194,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
    // deleted during conversion because they become unused only after
    // subsequent patterns bypass them.
    auto walkResult = func.walk([](Operation *op) {
-     if (!isa<shape::FromExtentsOp>(op))
+     if (!(isa<shape::FromExtentsOp>(op) || isa<shape::ConstShapeOp>(op)))
        return WalkResult::advance();
      if (!op->use_empty()) {
        op->emitError("could not be eliminated");
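The existing @cannot_erase_stray_shape_ops test covers the failure path of this walk: its shape.from_extents result is returned from the function, so the op still has a use when the walk runs, and the pass reports the "could not be eliminated" error instead of silently dropping it.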
@@ -20,10 +20,6 @@ func @erase_stray_shape_ops(%arg0: index) {
  // CHECK-NOT: shape.from_extents
  %0 = shape.from_extents %arg0
  "tcp.shape_observe_error"(%0) : (!shape.shape) -> none
  // CHECK-NOT: tcp.shape_observe_error
  // CHECK-NOT: shape.const_shape
  %1 = shape.const_shape []
  "tcp.shape_observe_error"(%1) : (!shape.shape) -> none
  return
}
@@ -34,3 +30,16 @@ func @cannot_erase_stray_shape_ops() -> !shape.shape {
  %0 = shape.from_extents
  return %0 : !shape.shape
}

// -----

// CHECK-LABEL: func @const_shape
func @const_shape() -> index {
  // CHECK-NOT: shape.const_shape
  %0 = shape.const_shape []
  %1 = shape.const_shape [7]
  %2 = tcp.get_extent %1, 0
  // CHECK: %[[C7:.*]] = constant 7 : index
  // CHECK: return %[[C7]]
  return %2 : index
}
@@ -0,0 +1,15 @@
// RUN: npcomp-run-mlir -input %s \
// RUN:   -invoke multiple_ops \
// RUN:   -arg-value="dense<1.0> : tensor<f32>" \
// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
// RUN:   -arg-value="dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>" \
// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN:   | FileCheck %s

// CHECK: output #0: dense<[
// CHECK-SAME: [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]]> : tensor<2x2xf32>
func @multiple_ops(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = "tcf.add"(%arg1, %arg2) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
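For reference, the expected output follows from ordinary broadcasting: adding [1.0] to [[1.0, 2.0], [3.0, 4.0]] gives [[2.0, 3.0], [4.0, 5.0]], and adding the scalar 1.0 to that gives [[3.0, 4.0], [5.0, 6.0]], which is what the CHECK lines match. The second add is the case this commit fixes: the rank-0 operand's shape is a constant shape, which must be lowered before it can be broadcast with the other operand's shape.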