Rework reference shape lowering based on upstream shape dialect changes.

* Primarily, the upstream shape dialect now uses tensor<?xindex> for non-erroring, immediate shape calculations (and shape_of now returns this type when applied to a tensor or memref); see the sketch after this list.
* Upstream passes for fully lowering these ops to standard ops do not yet exist, so the passes here are extended to handle the new convention.
* This should be seen as an intermediate state, necessary to integrate a new LLVM version; it needs more work and cleanup for generality.
* There is a good deal of awkwardness in these conversions. The hope is that additional upstream work will yield better-defined conversion paths once out of this intermediate state.
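As a rough illustration of the convention change (a sketch based on the tests updated in this commit, not itself part of the diff; %t and %s2 are placeholder values):

  // Previously, shape_of produced an opaque !shape.shape value:
  //   %s = "shape.shape_of"(%t) : (tensor<?xf32>) -> !shape.shape
  // Now it yields an extent tensor directly for tensor/memref operands:
  %s = shape.shape_of %t : tensor<?xf32> -> tensor<?xindex>
  // Any remaining !shape.shape value is bridged explicitly:
  %e = shape.to_extent_tensor %s2 : !shape.shape -> tensor<?xindex>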
pull/2/head
Stella Laurenzo 2020-08-02 22:06:12 -07:00
parent 624d4c6c50
commit fc484d1bd8
14 changed files with 210 additions and 138 deletions

View File

@@ -39,7 +39,7 @@ Broadcasts `operand` to the shape `shape`.
It is undefined behavior if such a broadcast is not legal.
}];
let arguments = (ins AnyRankedTensor:$operand, Shape_ShapeType:$shape);
let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$shape);
let results = (outs AnyRankedTensor:$result);
}
@@ -54,7 +54,7 @@ def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
let description = [{
Allocates a memref of the given shape.
}];
let arguments = (ins Shape_ShapeType:$shape);
let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
@@ -102,36 +102,10 @@ def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error", []> {
effecting ops, is not very well-defined, and needs to be worked
on/redesigned.
}];
let arguments = (ins Shape_ShapeType:$shape);
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
// TODO: ODS seems to create redeclared class members if we remove this,
// resulting in C++ compilation errors.
let results = (outs NoneType:$dummy);
}
// TODO: This probably belongs in the shape dialect.
def TCP_GetExtentOp : TCP_Op<"get_extent", [NoSideEffect]> {
let summary = "Gets the specified extent from a shape.";
let description = [{
Gets the specified extent from a shape.
This op has undefined behavior if the shape is an error.
}];
let arguments = (ins Shape_ShapeType:$shape, I64Attr:$dim);
let results = (outs Index:$extent);
let assemblyFormat = "$shape `,` $dim attr-dict";
let builders = [
// Helper to pass a simple integer instead of an integer attr.
OpBuilder<
[{
OpBuilder &builder, OperationState &result,
Value shape, int64_t dim
}],
[{
build(builder, result, shape, builder.getI64IntegerAttr(dim));
}]
>
];
}
#endif // TCP_OPS
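With tcp.get_extent removed, extent queries go through the upstream shape.get_extent op on an extent tensor instead. A hedged sketch of the replacement, using the assembly shown in the updated tests later in this commit (%shape is a placeholder tensor<?xindex> value):

  // Removed form (dimension passed as an attribute):
  //   %e = tcp.get_extent %shape, 0
  // Upstream form (dimension passed as an index operand):
  %c0 = constant 0 : index
  %e = shape.get_extent %shape, %c0 : tensor<?xindex>, index -> index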

View File

@@ -19,6 +19,12 @@ using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
RankedTensorType getExtentTensorType(Builder &builder) {
return RankedTensorType::get({ShapedType::kDynamicSize},
builder.getIndexType());
}
class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -34,6 +40,9 @@ public:
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
Value broadcastedExtents = rewriter.create<shape::ToExtentTensorOp>(
op.getLoc(), getExtentTensorType(rewriter), broadcastedShape);
// TODO: It's annoying to do the dynamic broadcast above then
// do the static transfer function here. Would be nice if they could
// somehow be unified.
@@ -43,9 +52,9 @@ public:
auto resultType =
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
op.getLoc(), resultType, op.lhs(), broadcastedShape);
op.getLoc(), resultType, op.lhs(), broadcastedExtents);
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
op.getLoc(), resultType, op.rhs(), broadcastedShape);
op.getLoc(), resultType, op.rhs(), broadcastedExtents);
Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
lhsBroadcasted, rhsBroadcasted);
rewriter.replaceOp(op, add);
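After this change, the tcf.add lowering emits IR roughly along the following lines (an illustrative sketch: value names are invented, and tcp ops are written in generic form since their custom assembly is not shown here):

  %lhs_shape = shape.shape_of %lhs : tensor<?xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<?xf32> -> tensor<?xindex>
  %bcast = "shape.broadcast"(%lhs_shape, %rhs_shape) : (tensor<?xindex>, tensor<?xindex>) -> !shape.shape
  %obs = "tcp.shape_observe_error"(%bcast) : (!shape.shape) -> none
  %extents = shape.to_extent_tensor %bcast : !shape.shape -> tensor<?xindex>
  %lhs_b = "tcp.broadcast_to"(%lhs, %extents) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?xf32>
  %rhs_b = "tcp.broadcast_to"(%rhs, %extents) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?xf32>
  %sum = "tcp.add"(%lhs_b, %rhs_b) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>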

View File

@@ -201,16 +201,11 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
// TODO: Remove this const pattern when lowering to shape.get_extent.
auto constIndex = op.getConstantIndex();
if (!constIndex)
return failure();
auto allocMemRef = op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
if (!allocMemRef)
return rewriter.notifyMatchFailure(op, "could not find alloc_memref");
rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, allocMemRef.shape(),
*constIndex);
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(
op, rewriter.getIndexType(), allocMemRef.shape(), op.index());
return success();
}
};
@@ -231,7 +226,7 @@ class LowerLinalgLoopDimOps
// remove this.
return !op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
});
target.addLegalOp<tcp::GetExtentOp>();
target.addLegalOp<shape::GetExtentOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
@@ -261,7 +256,8 @@ public:
SmallVector<Value, 6> dynamicExtents;
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) {
auto extent = rewriter.create<tcp::GetExtentOp>(op.getLoc(), shape, i);
auto extent =
rewriter.create<shape::GetExtentOp>(op.getLoc(), shape, i);
dynamicExtents.push_back(extent);
}
}
@@ -281,8 +277,9 @@ class LowerAllocMemRefOps
patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<tcp::AllocMemRefOp>();
target.addLegalOp<tcp::GetExtentOp>();
target.addLegalOp<shape::GetExtentOp>();
target.addLegalOp<AllocOp>();
target.addLegalOp<ConstantOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
@@ -433,8 +430,11 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
pm.addPass(createLowerRankedShapesPass());
// Run some cleanups.
// TODO: Some folding and DCE of dangling ops is still needed here. Once the
// invariants above are tightened up, the canonicalize should be moved into
// the optimize block.
pm.addPass(createCanonicalizerPass());
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}

View File

@@ -37,6 +37,21 @@ public:
} // namespace
namespace {
// Given an operand that is either a Shape or Extent Tensor, returns an
// Extent Tensor or nullptr if this cannot be locally determined.
// The return value, if !nullptr, will be a 1D RankedTensorType (with possibly
// unknown element).
Value findExtentsFromShape(Value operand, bool requireKnownRank) {
if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getRank() == 1 &&
(!requireKnownRank || tensorType.hasStaticShape())) {
return operand;
}
}
return nullptr;
}
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -44,14 +59,21 @@ public:
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::BroadcastOp::Adaptor adaptor(operands);
auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
if (!lhs || !rhs)
return rewriter.notifyMatchFailure(op, "operands not converted");
// When the ranks are statically known, generate non-branchy code.
// TODO: Generate rank-generic code.
auto lhsExtents = findExtentsFromShape(adaptor.lhs(), true);
auto rhsExtents = findExtentsFromShape(adaptor.rhs(), true);
if (!lhsExtents || !rhsExtents)
return rewriter.notifyMatchFailure(op, "dynamic extents not supported");
// Establish invariant that rank(lhs) >= rank(rhs).
if (lhs.extents().size() < rhs.extents().size())
std::swap(lhs, rhs);
auto rankDiscrepancy = lhs.extents().size() - rhs.extents().size();
auto lhsSize = lhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
auto rhsSize = rhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
if (lhsSize < rhsSize) {
std::swap(lhsExtents, rhsExtents);
std::swap(lhsSize, rhsSize);
}
auto rankDiscrepancy = lhsSize - rhsSize;
// Helper that creates IR
// ```
@@ -72,19 +94,31 @@ public:
rewriter.create<npcomprt::AbortIfOp>(op.getLoc(), bothTrue);
};
auto resultExtents = llvm::to_vector<6>(lhs.extents());
for (int i = 0, e = rhs.extents().size(); i < e; i++) {
auto lhsExtent = lhs.extents()[rankDiscrepancy + i];
auto rhsExtent = rhs.extents()[i];
SmallVector<Value, 6> resultExtents;
for (int i = 0, e = lhsSize; i < e; i++) {
auto lhsDim = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
auto lhsExtent = rewriter.create<ExtractElementOp>(
op.getLoc(), lhsExtents, ValueRange{lhsDim});
if (i < rankDiscrepancy) {
// Padded extent.
resultExtents.push_back(lhsExtent);
continue;
}
// Non-padded extent.
auto rhsDim =
rewriter.create<ConstantIndexOp>(op.getLoc(), i - rankDiscrepancy);
auto rhsExtent = rewriter.create<ExtractElementOp>(
op.getLoc(), rhsExtents, ValueRange{rhsDim});
auto ugt = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ugt,
lhsExtent, rhsExtent);
auto max =
auto resultExtent =
rewriter.create<SelectOp>(op.getLoc(), ugt, lhsExtent, rhsExtent);
auto &resultExtent = resultExtents[rankDiscrepancy + i];
resultExtent = max;
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
resultExtents.push_back(resultExtent);
}
// TODO: Remove the return type once ODS is fixed to do proper inference.
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
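Per non-padded result dimension, the loop above emits standard ops roughly like the following sketch (for some dimension i with d = rankDiscrepancy; %lhs_extents and %rhs_extents are the extent tensors, and the guard built by createAbortIfIllegalBroadcastExtent is summarized in a comment since its body is not shown here):

  %lhs_dim = constant 2 : index        // i, illustrative
  %lhs_extent = extract_element %lhs_extents[%lhs_dim] : tensor<?xindex>
  %rhs_dim = constant 1 : index        // i - d, illustrative
  %rhs_extent = extract_element %rhs_extents[%rhs_dim] : tensor<?xindex>
  %gt = cmpi "ugt", %lhs_extent, %rhs_extent : index
  %result_extent = select %gt, %lhs_extent, %rhs_extent : index
  // ... followed by npcomprt.abort_if guards checking that each input extent
  // is a legal broadcast source for %result_extent.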
@@ -93,26 +127,44 @@ public:
};
} // namespace
// Rewrite `get_extent(from_extents(x1,x2,x3), N) -> xN`
//
// TODO: this should be a fold on tcp::GetExtentOp.
// (though then the contract of this pass depends on that set of folds,
// which isn't great)
//
// Also, we use OpConversionPattern to get post-rewrite operands as above.
namespace {
class LowerShapeGetExtentOp : public OpConversionPattern<tcp::GetExtentOp> {
class LowerShapeToExtentTensorOp
: public OpConversionPattern<shape::ToExtentTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
tcp::GetExtentOp::Adaptor adaptor(operands);
auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
if (!fromExtents)
return rewriter.notifyMatchFailure(op, "not a from_extents op");
int64_t dim = op.dim().getLimitedValue();
rewriter.replaceOp(op, ValueRange(fromExtents.extents())[dim]);
shape::ToExtentTensorOpAdaptor adaptor(operands);
if (adaptor.input().getType().isa<shape::ShapeType>()) {
// Convert by matching to a producing FromExtentsOp.
auto fromExtents = adaptor.input().getDefiningOp<shape::FromExtentsOp>();
if (!fromExtents) {
return rewriter.notifyMatchFailure(op, "not a from_extents op");
}
rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op,
fromExtents.extents());
return success();
}
// Assume that it is already an extent tensor.
// TODO: Since these ops are all multi-type, there should be a utility
// for switching on the allowable types instead of just assuming that it
// is an extent tensor.
rewriter.replaceOp(op, adaptor.input());
return success();
}
};
class LowerShapeGetExtentOp : public OpConversionPattern<shape::GetExtentOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::GetExtentOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, adaptor.shape(),
adaptor.dim());
return success();
}
};
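Taken together, these two patterns reduce the remaining shape ops to standard ops, schematically:

  // shape.to_extent_tensor of a from_extents becomes a tensor_from_elements:
  //   to_extent_tensor(from_extents(%a, %b))  ->  tensor_from_elements(%a, %b) : tensor<2xindex>
  // shape.get_extent on an extent tensor becomes a plain element read:
  //   get_extent(%t, %i)  ->  extract_element %t[%i] : tensor<?xindex>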
@@ -136,20 +188,13 @@ public:
} // namespace
// Basic invariant of this pass:
// Every def of a !shape.shape type is replaced with a
// `shape.from_extents` op.
// When converting an op, look for the `shape.from_extents` op that
// defined all operands, then do a computation on the extents (i.e.
// operands to the `shape.from_extents` op) and produce a
// `shape.from_extents` op.
// Every `shape.from_extents` op operating on an extent tensor
// (`tensor<?xindex>`) is replaced by corresponding standard ops and folded
// away (for the ranked case, it should be possible to eliminate these).
//
// We expect that previous passes have inserted a "root" set of
// shape::FromExtentsOp's that allow this process to get started.
//
// We then use this to resolve get_extent ops by using a rewrite
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
// maximally many places due to the above invariant.
//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
@@ -167,6 +212,11 @@ public:
// ranks of use-def cycles ahead of time or optimistically assume that
// backedges will match the rank of forward edges, and somehow be robust
// when that assumption fails.
//
// TODO: Add in a fold of
// `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore
// the above invariant without relying on a subsequent canonicalization
// step.
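For reference, that fold would look roughly like the following sketch (not code in this commit):

  //   %t = tensor_from_elements(%x0, %x1) : tensor<2xindex>
  //   %c1 = constant 1 : index
  //   %e = extract_element %t[%c1] : tensor<2xindex>   // would fold to %x1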
namespace {
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
void runOnOperation() {
@@ -177,12 +227,14 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
patterns.insert<LowerConstShapeOp>(context);
patterns.insert<LowerShapeBroadcastOp>(context);
patterns.insert<LowerShapeGetExtentOp>(context);
patterns.insert<LowerShapeToExtentTensorOp>(context);
patterns.insert<EraseShapeObserveErrorOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<shape::ShapeOfOp>();
target.addIllegalOp<shape::BroadcastOp>();
target.addIllegalOp<tcp::GetExtentOp>();
target.addIllegalOp<shape::GetExtentOp>();
target.addLegalOp<shape::FromExtentsOp>();
target.addIllegalOp<shape::ToExtentTensorOp>();
target.addLegalOp<npcomprt::AbortIfOp>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
@@ -196,11 +248,12 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
auto walkResult = func.walk([](Operation *op) {
if (!isa<shape::FromExtentsOp>(op))
return WalkResult::advance();
if (!op->use_empty()) {
if (op->use_empty()) {
op->erase();
} else {
op->emitError("could not be eliminated");
return WalkResult::interrupt();
}
op->erase();
return WalkResult::advance();
});
if (walkResult.wasInterrupted())

View File

@@ -62,9 +62,9 @@ public:
// TODO: handle output rank > input rank.
for (int i = 0, e = resultType.getRank(); i < e; i++) {
Value outputExtent = rewriter.create<tcp::GetExtentOp>(
op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i));
Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
Value outputExtent = rewriter.create<shape::GetExtentOp>(
op.getLoc(), rewriter.getIndexType(), op.shape(), dimIndex);
outputExtents.push_back(outputExtent);
}
int rankDiff = resultType.getRank() - inputType.getRank();
@@ -127,12 +127,10 @@ public:
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
// TODO: Remove this const pattern when lowering to shape.get_extent.
auto constIndex = op.getConstantIndex();
if (!constIndex)
return failure();
auto shape =
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.memrefOrTensor());
rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, shape, *constIndex);
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, rewriter.getIndexType(),
shape, op.index());
return success();
}
};
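The net effect is that a std.dim with an arbitrary (possibly non-constant) index now resolves through the shape dialect, roughly (names illustrative, assembly as in the updated tests):

  // %d = dim %t, %i : tensor<?xf32>   becomes:
  %shape = shape.shape_of %t : tensor<?xf32> -> tensor<?xindex>
  %d = shape.get_extent %shape, %i : tensor<?xindex>, index -> index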

View File

@@ -132,8 +132,14 @@ public:
extents.push_back(rewriter.create<npcomprt::GetExtentOp>(
op.getLoc(), rewriter.getIndexType(), adaptor.arg(), ci));
}
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
op, rewriter.getType<shape::ShapeType>(), extents);
auto newShape = rewriter.create<shape::FromExtentsOp>(
op.getLoc(), rewriter.getType<shape::ShapeType>(), extents);
// TODO: Provide a builder that doesn't require the result type.
rewriter.replaceOpWithNewOp<shape::ToExtentTensorOp>(
op,
RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType()),
newShape);
return success();
}
};
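At the runtime ABI boundary, the per-dimension extents obtained from npcomprt are reassembled into a shape and then immediately converted to an extent tensor, roughly (a sketch with illustrative names; some operand types elided as in the CHECK lines below, where %c0_i32 stands for an i32 constant defined earlier):

  %e0 = npcomprt.get_extent %arg, %c0_i32
  %s = shape.from_extents %e0
  %extents = shape.to_extent_tensor %s : !shape.shape -> tensor<?xindex>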
@@ -203,6 +209,7 @@ static LogicalResult doDialectConversion(ModuleOp module) {
target.addIllegalOp<shape::ShapeOfOp>();
target.addLegalOp<ConstantOp>();
target.addLegalOp<shape::FromExtentsOp>();
target.addLegalOp<shape::ToExtentTensorOp>();
target.addLegalOp<npcomprt::GetExtentOp>();
patterns.insert<LowerGlobalOp>(context);

View File

@@ -9,3 +9,11 @@ func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @multiple_ops
func @multiple_ops(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tcf.add"(%arg1, %arg2) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

View File

@@ -0,0 +1,12 @@
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail
// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline{optimize} | FileCheck %s --dump-input=fail
// -----
// CHECK-LABEL: func @global_add
func @global_add() -> tensor<2xf32> attributes {iree.module.export} {
%cst = constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
%cst_0 = constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
%0 = "tcf.add"(%cst, %cst_0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%1 = "tcf.add"(%cst_0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}

View File

@@ -1,26 +1,31 @@
// RUN: npcomp-opt -lower-alloc-memref-ops <%s | FileCheck %s
// RUN: npcomp-opt -split-input-file -lower-alloc-memref-ops <%s | FileCheck %s
// CHECK-LABEL: func @basic
func @basic(%arg0: !shape.shape) {
// CHECK: %[[E:.*]] = tcp.get_extent %arg0, 0
func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
// CHECK: %[[I:.*]] = constant 0 : index
// CHECK: %[[E:.*]] = shape.get_extent %arg0, %[[I]]
// CHECK: alloc(%[[E]])
%0 = tcp.alloc_memref %arg0 : memref<?xf32>
return
return %0 : memref<?xf32>
}
// CHECK: func @all_static(%arg0: !shape.shape)
func @all_static(%arg0: !shape.shape) {
// CHECK-NOT: tcp.get_extent
// -----
// CHECK-LABEL: func @all_static
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
// CHECK-NOT: shape.get_extent
// CHECK: alloc()
%0 = tcp.alloc_memref %arg0 : memref<3x4x5xf32>
return
return %0 : memref<3x4x5xf32>
}
// CHECK: func @some_static(%arg0: !shape.shape)
func @some_static(%arg0: !shape.shape) {
// CHECK: %[[E1:.*]] = tcp.get_extent %arg0, 1
// CHECK: %[[E3:.*]] = tcp.get_extent %arg0, 3
// -----
// CHECK-LABEL: func @some_static
func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
// CHECK-DAG: %[[I1:.*]] = constant 1 : index
// CHECK-DAG: %[[E1:.*]] = shape.get_extent %arg0, %[[I1]]
// CHECK-DAG: %[[I3:.*]] = constant 3 : index
// CHECK-DAG: %[[E3:.*]] = shape.get_extent %arg0, %[[I3]]
// CHECK: alloc(%[[E1]], %[[E3]])
%0 = tcp.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
return
return %0 : memref<3x?x5x?x7xf32>
}

View File

@@ -1,19 +1,24 @@
// RUN: npcomp-opt -lower-ranked-shapes <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @broadcast_rank2_rank1
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
// CHECK-NOT: shape.broadcast
// CHECK-NOT: tcp.get_extent
// CHECK-NOT: shape.from_extents
%0 = shape.from_extents %arg0, %arg1
%1 = shape.from_extents %arg2
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%e0 = tcp.get_extent %2, 0
%e1 = tcp.get_extent %2, 1
%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
%2 = shape.from_extents %arg2
%3 = shape.to_extent_tensor %2 : !shape.shape -> tensor<?xindex>
%4 = "shape.broadcast"(%1, %3) : (tensor<?xindex>, tensor<?xindex>) -> !shape.shape
%5 = shape.to_extent_tensor %4 : !shape.shape -> tensor<?xindex>
%c0 = constant 0 : index
%c1 = constant 1 : index
%e0 = shape.get_extent %5, %c0 : tensor<?xindex>, index -> index
%e1 = shape.get_extent %5, %c1 : tensor<?xindex>, index -> index
return %e0, %e1 : index, index
}
// -----
// CHECK-LABEL: func @erase_stray_shape_ops
func @erase_stray_shape_ops(%arg0: index) {
// CHECK-NOT: tcp.shape_observe_error
@@ -24,7 +29,6 @@ func @erase_stray_shape_ops(%arg0: index) {
}
// -----
func @cannot_erase_stray_shape_ops() -> !shape.shape {
// expected-error @+1 {{could not be eliminated}}
%0 = shape.from_extents
@@ -32,14 +36,15 @@ func @cannot_erase_stray_shape_ops() -> !shape.shape {
}
// -----
// TODO: Remove this as it is now just testing shape and std ops.
// CHECK-LABEL: func @const_shape
func @const_shape() -> index {
// CHECK-NOT: shape.const_shape
%0 = shape.const_shape []
%1 = shape.const_shape [7]
%2 = tcp.get_extent %1, 0
%0 = shape.const_shape [] : tensor<?xindex>
%1 = shape.const_shape [7] : tensor<?xindex>
%2 = constant 0 : index
%3 = shape.get_extent %1, %2 : tensor<?xindex>, index -> index
// CHECK: %[[C7:.*]] = constant 7 : index
// CHECK: return %[[C7]]
return %2 : index
return %3 : index
}

View File

@@ -18,17 +18,18 @@ func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[VAL_1:.*]] = constant 0 : i32
// CHECK: %[[VAL_2:.*]] = npcomprt.get_extent %[[VAL_0]], %[[VAL_1]]
// CHECK: %[[VAL_3:.*]] = shape.from_extents %[[VAL_2]]
// CHECK: %[[VAL_4:.*]] = tcp.alloc_memref %[[VAL_3]] : memref<?xf32>
// CHECK: %[[VAL_5:.*]] = npcomprt.to_memref %[[VAL_0]] : memref<*xf32>
// CHECK: %[[VAL_6:.*]] = memref_cast %[[VAL_5]] : memref<*xf32> to memref<?xf32>
// CHECK: linalg.copy(%[[VAL_6]], %[[VAL_4]]) : memref<?xf32>, memref<?xf32>
// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_4]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[VAL_8:.*]] = npcomprt.from_memref %[[VAL_7]] : memref<*xf32>
// CHECK: return %[[VAL_8]] : !npcomprt.tensor
// CHECK: %[[VAL_4:.*]] = shape.to_extent_tensor %[[VAL_3]]
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
// CHECK: %[[VAL_6:.*]] = npcomprt.to_memref %[[VAL_0]] : memref<*xf32>
// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_6]] : memref<*xf32> to memref<?xf32>
// CHECK: linalg.copy(%[[VAL_7]], %[[VAL_5]]) : memref<?xf32>, memref<?xf32>
// CHECK: %[[VAL_8:.*]] = memref_cast %[[VAL_5]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[VAL_9:.*]] = npcomprt.from_memref %[[VAL_8]] : memref<*xf32>
// CHECK: return %[[VAL_9]] : !npcomprt.tensor
// CHECK: }
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%shape = shape.shape_of %arg0 : tensor<?xf32>
%shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
%memref = tcp.alloc_memref %shape : memref<?xf32>
tensor_store %arg0, %memref : memref<?xf32>
%ret = tensor_load %memref : memref<?xf32>
@@ -40,7 +41,7 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
tcp.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
func @gets_global() -> tensor<10xf32> {
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>

View File

@@ -1,20 +1,20 @@
// RUN: npcomp-opt -resolve-shape-of-ops <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @basic
func @basic(%arg0: !shape.shape) -> !shape.shape {
func @basic(%arg0: tensor<?xindex>) -> tensor<?xindex> {
%memref = tcp.alloc_memref %arg0 : memref<?xf32>
%tensor = tensor_load %memref : memref<?xf32>
%shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> !shape.shape
%shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> tensor<?xindex>
// CHECK: return %arg0
return %shape : !shape.shape
return %shape : tensor<?xindex>
}
// -----
// CHECK-LABEL: func @arg_unresolved_ok
func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> !shape.shape {
%0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> !shape.shape
return %0 : !shape.shape
func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> tensor<?xindex> {
%0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> tensor<?xindex>
return %0 : tensor<?xindex>
}
// -----
@@ -22,9 +22,9 @@ func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> !shape.shape {
// CHECK-LABEL: func @TODO_bb_arg_unresolved_not_ok
// TODO: This should emit a diagnostic, but doesn't. Why?
// addDynamicallyLegalOp isn't working as I expect.
func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> !shape.shape {
func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xindex> {
cond_br %arg0, ^bb1(%arg1: tensor<?xf32>), ^bb1(%arg2: tensor<?xf32>)
^bb1(%bbarg: tensor<?xf32>):
%0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> !shape.shape
return %0 : !shape.shape
%0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> tensor<?xindex>
return %0 : tensor<?xindex>
}

View File

@@ -3,7 +3,7 @@
// CHECK-LABEL: func @basic
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> tensor<?xindex>
// CHECK: %[[SRCMEMREF:.+]] = tcp.alloc_memref
%src_memref = tcp.alloc_memref %shape : memref<?xf32>

View File

@@ -2,10 +2,10 @@
// RUN: -invoke scalar \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// RUN: | FileCheck %s
// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tcf.add"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}