mirror of https://github.com/llvm/torch-mlir
Rework reference shape lowering based on upstream shape dialect changes.
* Primarily, the upstream shape dialect now uses tensor<?xindex> for non-erroring, immediate shape calculations (and will return this for shape_of of a tensor or memref).
* In addition, upstream passes do not yet exist for fully lowering to standard ops, so the passes here need to be extended to handle this new convention.
* This should be seen as an intermediate state; it is necessary to integrate a new LLVM version and needs more work and cleanup for generality.
* There is a good deal of awkwardness in these conversions. The hope is that additional upstream work will yield better-defined conversion paths once out of this intermediate state.
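As an aside for readers of this commit (the snippet below is not part of the change itself): the convention described above can be illustrated with a small MLIR sketch, written against the shape dialect syntax used at the time of this commit. The function name and values are made up for illustration, and the op spellings may differ in later upstream versions.

    func @extent_tensor_example(%arg0: tensor<?x?xf32>) -> index {
      // shape_of of a tensor now produces an extent tensor (tensor<?xindex>)
      // rather than a !shape.shape value.
      %shape = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<?xindex>
      // Individual extents are read out of the extent tensor using an index.
      %c0 = constant 0 : index
      %e0 = shape.get_extent %shape, %c0 : tensor<?xindex>, index -> index
      return %e0 : index
    }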
parent 624d4c6c50
commit fc484d1bd8
@@ -39,7 +39,7 @@ Broadcasts `operand` to the shape `shape`.
It is undefined behavior if such a broadcast is not legal.
}];
-let arguments = (ins AnyRankedTensor:$operand, Shape_ShapeType:$shape);
+let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$shape);
let results = (outs AnyRankedTensor:$result);
}

@@ -54,7 +54,7 @@ def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
let description = [{
Allocates a memref of the given shape.
}];
-let arguments = (ins Shape_ShapeType:$shape);
+let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}

@@ -102,36 +102,10 @@ def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error", []> {
effecting ops, is not very well-defined, and needs to be worked
on/redesigned.
}];
-let arguments = (ins Shape_ShapeType:$shape);
+let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
// TODO: ODS seems to create redeclared class members if we remove this,
// resulting in C++ compilation errors.
let results = (outs NoneType:$dummy);
}

-// TODO: This probably belongs in the shape dialect.
-def TCP_GetExtentOp : TCP_Op<"get_extent", [NoSideEffect]> {
-let summary = "Gets the specified extent from a shape.";
-let description = [{
-Gets the specified extent from a shape.
-
-This op has undefined behavior if the shape is an error.
-}];
-let arguments = (ins Shape_ShapeType:$shape, I64Attr:$dim);
-let results = (outs Index:$extent);
-let assemblyFormat = "$shape `,` $dim attr-dict";
-
-let builders = [
-// Helper to pass a simple integer instead of an integer attr.
-OpBuilder<
-[{
-OpBuilder &builder, OperationState &result,
-Value shape, int64_t dim
-}],
-[{
-build(builder, result, shape, builder.getI64IntegerAttr(dim));
-}]
->
-];
-}
-
#endif // TCP_OPS
@@ -19,6 +19,12 @@ using namespace mlir;
using namespace mlir::NPCOMP;

namespace {
+
+RankedTensorType getExtentTensorType(Builder &builder) {
+return RankedTensorType::get({ShapedType::kDynamicSize},
+builder.getIndexType());
+}
+
class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
public:
using OpRewritePattern::OpRewritePattern;

@@ -34,6 +40,9 @@ public:
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
+Value broadcastedExtents = rewriter.create<shape::ToExtentTensorOp>(
+op.getLoc(), getExtentTensorType(rewriter), broadcastedShape);
+
// TODO: It's annoying to do the dynamic broadcast above then
// do the static transfer function here. Would be nice if they could
// somehow be unified.

@@ -43,9 +52,9 @@ public:
auto resultType =
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
-op.getLoc(), resultType, op.lhs(), broadcastedShape);
+op.getLoc(), resultType, op.lhs(), broadcastedExtents);
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
-op.getLoc(), resultType, op.rhs(), broadcastedShape);
+op.getLoc(), resultType, op.rhs(), broadcastedExtents);
Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
lhsBroadcasted, rhsBroadcasted);
rewriter.replaceOp(op, add);
@@ -201,16 +201,11 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
-// TODO: Remove this const pattern when lowering to shape.get_extent.
-auto constIndex = op.getConstantIndex();
-if (!constIndex)
-return failure();
-
auto allocMemRef = op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
if (!allocMemRef)
return rewriter.notifyMatchFailure(op, "could not find alloc_memref");
-rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, allocMemRef.shape(),
-*constIndex);
+rewriter.replaceOpWithNewOp<shape::GetExtentOp>(
+op, rewriter.getIndexType(), allocMemRef.shape(), op.index());
return success();
}
};

@@ -231,7 +226,7 @@ class LowerLinalgLoopDimOps
// remove this.
return !op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
});
-target.addLegalOp<tcp::GetExtentOp>();
+target.addLegalOp<shape::GetExtentOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}

@@ -261,7 +256,8 @@ public:
SmallVector<Value, 6> dynamicExtents;
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) {
-auto extent = rewriter.create<tcp::GetExtentOp>(op.getLoc(), shape, i);
+auto extent =
+rewriter.create<shape::GetExtentOp>(op.getLoc(), shape, i);
dynamicExtents.push_back(extent);
}
}

@@ -281,8 +277,9 @@ class LowerAllocMemRefOps
patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<tcp::AllocMemRefOp>();
-target.addLegalOp<tcp::GetExtentOp>();
+target.addLegalOp<shape::GetExtentOp>();
target.addLegalOp<AllocOp>();
+target.addLegalOp<ConstantOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}

@@ -433,8 +430,11 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
pm.addPass(createLowerRankedShapesPass());

// Run a some cleanups.
+// TODO: Some folding and DCE of dangling ops is still needed here. Once the
+// invariants above are tightened up, the canonicalize should be moved into
+// the optimize block.
pm.addPass(createCanonicalizerPass());
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
@@ -37,6 +37,21 @@ public:
} // namespace

namespace {
+
+// Given an operand that is either a Shape or Extent Tensor, returns an
+// Extent Tensor or nullptr if this cannot be locally determined.
+// The return value, if !nullptr, will be a 1D RankedTensorType (with possibly
+// unknown element).
+Value findExtentsFromShape(Value operand, bool requireKnownRank) {
+if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+if (tensorType.getRank() == 1 &&
+(!requireKnownRank || tensorType.hasStaticShape())) {
+return operand;
+}
+}
+return nullptr;
+}
+
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
using OpConversionPattern::OpConversionPattern;

@@ -44,14 +59,21 @@ public:
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
shape::BroadcastOp::Adaptor adaptor(operands);
-auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
-auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
-if (!lhs || !rhs)
-return rewriter.notifyMatchFailure(op, "operands not converted");
+// When the ranks are statically known, generate non-branchy code.
+// TODO: Generate rank-generic code.
+auto lhsExtents = findExtentsFromShape(adaptor.lhs(), true);
+auto rhsExtents = findExtentsFromShape(adaptor.rhs(), true);
+if (!lhsExtents || !rhsExtents)
+return rewriter.notifyMatchFailure(op, "dynamic extents not supported");
+
// Establish invariant that rank(lhs) >= rank(rhs).
-if (lhs.extents().size() < rhs.extents().size())
-std::swap(lhs, rhs);
-auto rankDiscrepancy = lhs.extents().size() - rhs.extents().size();
+auto lhsSize = lhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
+auto rhsSize = rhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
+if (lhsSize < rhsSize) {
+std::swap(lhsExtents, rhsExtents);
+std::swap(lhsSize, rhsSize);
+}
+auto rankDiscrepancy = lhsSize - rhsSize;

// Helper that creates IR
// ```

@@ -72,19 +94,31 @@ public:
rewriter.create<npcomprt::AbortIfOp>(op.getLoc(), bothTrue);
};

-auto resultExtents = llvm::to_vector<6>(lhs.extents());
-for (int i = 0, e = rhs.extents().size(); i < e; i++) {
-auto lhsExtent = lhs.extents()[rankDiscrepancy + i];
-auto rhsExtent = rhs.extents()[i];
+SmallVector<Value, 6> resultExtents;
+for (int i = 0, e = lhsSize; i < e; i++) {
+auto lhsDim = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
+auto lhsExtent = rewriter.create<ExtractElementOp>(
+op.getLoc(), lhsExtents, ValueRange{lhsDim});
+if (i < rankDiscrepancy) {
+// Padded extent.
+resultExtents.push_back(lhsExtent);
+continue;
+}
+
+// Non-padded extent.
+auto rhsDim =
+rewriter.create<ConstantIndexOp>(op.getLoc(), i - rankDiscrepancy);
+auto rhsExtent = rewriter.create<ExtractElementOp>(
+op.getLoc(), rhsExtents, ValueRange{rhsDim});
auto ugt = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ugt,
lhsExtent, rhsExtent);
-auto max =
+auto resultExtent =
rewriter.create<SelectOp>(op.getLoc(), ugt, lhsExtent, rhsExtent);
-auto &resultExtent = resultExtents[rankDiscrepancy + i];
-resultExtent = max;
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
+resultExtents.push_back(resultExtent);
}

// TODO: Remove the return type once ODS is fixed to do proper inference.
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
op, shape::ShapeType::get(rewriter.getContext()), resultExtents);

@@ -93,26 +127,44 @@ public:
};
} // namespace

-// Rewrite `get_extent(from_extents(x1,x2,x3), N) -> xN`
-//
-// TODO: this should be a fold on tcp::GetExtentOp.
-// (though then the contract of this pass depends on that set of folds,
-// which isn't great)
-//
-// Also, we use OpConversionPattern to get post-rewrite operands as above.
namespace {
-class LowerShapeGetExtentOp : public OpConversionPattern<tcp::GetExtentOp> {
+class LowerShapeToExtentTensorOp
+: public OpConversionPattern<shape::ToExtentTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
-matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
+matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
-tcp::GetExtentOp::Adaptor adaptor(operands);
-auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
-if (!fromExtents)
-return rewriter.notifyMatchFailure(op, "not a from_extents op");
-int64_t dim = op.dim().getLimitedValue();
-rewriter.replaceOp(op, ValueRange(fromExtents.extents())[dim]);
+shape::ToExtentTensorOpAdaptor adaptor(operands);
+if (adaptor.input().getType().isa<shape::ShapeType>()) {
+// Convert by matching to a producing FromExtentsOp.
+auto fromExtents = adaptor.input().getDefiningOp<shape::FromExtentsOp>();
+if (!fromExtents) {
+return rewriter.notifyMatchFailure(op, "not a from_extents op");
+}
+rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op,
+fromExtents.extents());
+return success();
+}
+
+// Assume that it is already an extent tensor.
+// TODO: Since these ops are all multi-type, there should be a utility
+// for switching on the allowable types instead of just assuming that it
+// is an extent tensor.
+rewriter.replaceOp(op, adaptor.input());
return success();
}
};
+
+class LowerShapeGetExtentOp : public OpConversionPattern<shape::GetExtentOp> {
+public:
+using OpConversionPattern::OpConversionPattern;
+LogicalResult
+matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
+ConversionPatternRewriter &rewriter) const override {
+shape::GetExtentOp::Adaptor adaptor(operands);
+rewriter.replaceOpWithNewOp<ExtractElementOp>(op, adaptor.shape(),
+adaptor.dim());
+return success();
+}
+};

@@ -136,20 +188,13 @@ public:
} // namespace

// Basic invariant of this pass:
-// Every def of a !shape.shape type is replaced with a
-// `shape.from_extents` op.
-// When converting an op, look for the `shape.from_extents` op that
-// defined all operands, then do a computation on the extents (i.e.
-// operands to the `shape.from_extents` op) and produce a
-// `shape.from_extents` op.
+// Every `shape.from_extents` op operating on an extent tensor
+// (`tensor<?xindex>`) is replaced by corresponding standard ops and folded
+// away (for the ranked case, it should be possible to eliminate these).
//
// We expect that previous passes have inserted a "root" set of
// shape::FromExtentsOp's that allow this process to get started.
//
-// We then use this to resolve get_extent ops by using a rewrite
-// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
-// maximally many places due to the above invariant.
-//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.

@@ -167,6 +212,11 @@ public:
// ranks of use-def cycles ahead of time or optimistically assume that
// backedges will match the rank of forward edges, and somehow be robust
// when that assumption fails.
+//
+// TODO: Add in a fold of
+// `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore
+// the above invariant without relying on a subsequent canonicalization
+// step.
namespace {
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
void runOnOperation() {

@@ -177,12 +227,14 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
patterns.insert<LowerConstShapeOp>(context);
patterns.insert<LowerShapeBroadcastOp>(context);
patterns.insert<LowerShapeGetExtentOp>(context);
+patterns.insert<LowerShapeToExtentTensorOp>(context);
patterns.insert<EraseShapeObserveErrorOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<shape::ShapeOfOp>();
target.addIllegalOp<shape::BroadcastOp>();
-target.addIllegalOp<tcp::GetExtentOp>();
+target.addIllegalOp<shape::GetExtentOp>();
target.addLegalOp<shape::FromExtentsOp>();
+target.addIllegalOp<shape::ToExtentTensorOp>();
target.addLegalOp<npcomprt::AbortIfOp>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<tcp::ShapeObserveErrorOp>();

@@ -196,11 +248,12 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
auto walkResult = func.walk([](Operation *op) {
if (!isa<shape::FromExtentsOp>(op))
return WalkResult::advance();
-if (!op->use_empty()) {
+if (op->use_empty()) {
+op->erase();
+} else {
op->emitError("could not be eliminated");
return WalkResult::interrupt();
}
-op->erase();
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
@@ -62,9 +62,9 @@ public:

// TODO: handle output rank > input rank.
for (int i = 0, e = resultType.getRank(); i < e; i++) {
-Value outputExtent = rewriter.create<tcp::GetExtentOp>(
-op.getLoc(), op.shape(), rewriter.getI64IntegerAttr(i));
+Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
+Value outputExtent = rewriter.create<shape::GetExtentOp>(
+op.getLoc(), rewriter.getIndexType(), op.shape(), dimIndex);
outputExtents.push_back(outputExtent);
}
int rankDiff = resultType.getRank() - inputType.getRank();

@@ -127,12 +127,10 @@ public:
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
-// TODO: Remove this const pattern when lowering to shape.get_extent.
-auto constIndex = op.getConstantIndex();
-if (!constIndex)
-return failure();
auto shape =
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.memrefOrTensor());
-rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, shape, *constIndex);
+rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, rewriter.getIndexType(),
+shape, op.index());
return success();
}
};
@@ -132,8 +132,14 @@ public:
extents.push_back(rewriter.create<npcomprt::GetExtentOp>(
op.getLoc(), rewriter.getIndexType(), adaptor.arg(), ci));
}
-rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
-op, rewriter.getType<shape::ShapeType>(), extents);
+auto newShape = rewriter.create<shape::FromExtentsOp>(
+op.getLoc(), rewriter.getType<shape::ShapeType>(), extents);
+// TODO: Provide a builder that doesn't require the result type.
+rewriter.replaceOpWithNewOp<shape::ToExtentTensorOp>(
+op,
+RankedTensorType::get({ShapedType::kDynamicSize},
+rewriter.getIndexType()),
+newShape);
return success();
}
};

@@ -203,6 +209,7 @@ static LogicalResult doDialectConversion(ModuleOp module) {
target.addIllegalOp<shape::ShapeOfOp>();
target.addLegalOp<ConstantOp>();
target.addLegalOp<shape::FromExtentsOp>();
+target.addLegalOp<shape::ToExtentTensorOp>();
target.addLegalOp<npcomprt::GetExtentOp>();

patterns.insert<LowerGlobalOp>(context);
@@ -9,3 +9,11 @@ func @rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = "tcf.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
+
+// -----
+// CHECK-LABEL: func @multiple_ops
+func @multiple_ops(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+%0 = "tcf.add"(%arg1, %arg2) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+return %1 : tensor<?x?xf32>
+}
@@ -0,0 +1,12 @@
+// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s -pass-pipeline=e2e-lowering-pipeline{optimize} | FileCheck %s --dump-input=fail
+
+// -----
+// CHECK-LABEL: func @global_add
+func @global_add() -> tensor<2xf32> attributes {iree.module.export} {
+%cst = constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
+%cst_0 = constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
+%0 = "tcf.add"(%cst, %cst_0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+%1 = "tcf.add"(%cst_0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+return %1 : tensor<2xf32>
+}
@@ -1,26 +1,31 @@
-// RUN: npcomp-opt -lower-alloc-memref-ops <%s | FileCheck %s
+// RUN: npcomp-opt -split-input-file -lower-alloc-memref-ops <%s | FileCheck %s

// CHECK-LABEL: func @basic
-func @basic(%arg0: !shape.shape) {
-// CHECK: %[[E:.*]] = tcp.get_extent %arg0, 0
+func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
+// CHECK: %[[I:.*]] = constant 0 : index
+// CHECK: %[[E:.*]] = shape.get_extent %arg0, %[[I]]
// CHECK: alloc(%[[E]])
%0 = tcp.alloc_memref %arg0 : memref<?xf32>
-return
+return %0 : memref<?xf32>
}

-// CHECK: func @all_static(%arg0: !shape.shape)
-func @all_static(%arg0: !shape.shape) {
-// CHECK-NOT: tcp.get_extent
+// -----
+// CHECK-LABEL: func @all_static
+func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
+// CHECK-NOT: shape.get_extent
// CHECK: alloc()
%0 = tcp.alloc_memref %arg0 : memref<3x4x5xf32>
-return
+return %0 : memref<3x4x5xf32>
}

-// CHECK: func @some_static(%arg0: !shape.shape)
-func @some_static(%arg0: !shape.shape) {
-// CHECK: %[[E1:.*]] = tcp.get_extent %arg0, 1
-// CHECK: %[[E3:.*]] = tcp.get_extent %arg0, 3
+// -----
+// CHECK-LABEL: func @some_static
+func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
+// CHECK-DAG: %[[I1:.*]] = constant 1 : index
+// CHECK-DAG: %[[E1:.*]] = shape.get_extent %arg0, %[[I1]]
+// CHECK-DAG: %[[I3:.*]] = constant 3 : index
+// CHECK-DAG: %[[E3:.*]] = shape.get_extent %arg0, %[[I3]]
// CHECK: alloc(%[[E1]], %[[E3]])
%0 = tcp.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
-return
+return %0 : memref<3x?x5x?x7xf32>
}
@@ -1,19 +1,24 @@
// RUN: npcomp-opt -lower-ranked-shapes <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail


// CHECK-LABEL: func @broadcast_rank2_rank1
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
// CHECK-NOT: shape.broadcast
-// CHECK-NOT: tcp.get_extent
+// CHECK-NOT: shape.from_extents
%0 = shape.from_extents %arg0, %arg1
-%1 = shape.from_extents %arg2
-%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
-%e0 = tcp.get_extent %2, 0
-%e1 = tcp.get_extent %2, 1
+%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
+%2 = shape.from_extents %arg2
+%3 = shape.to_extent_tensor %2 : !shape.shape -> tensor<?xindex>
+%4 = "shape.broadcast"(%1, %3) : (tensor<?xindex>, tensor<?xindex>) -> !shape.shape
+%5 = shape.to_extent_tensor %4 : !shape.shape -> tensor<?xindex>
+%c0 = constant 0 : index
+%c1 = constant 1 : index
+%e0 = shape.get_extent %5, %c0 : tensor<?xindex>, index -> index
+%e1 = shape.get_extent %5, %c1 : tensor<?xindex>, index -> index
return %e0, %e1 : index, index
}

// -----
// CHECK-LABEL: func @erase_stray_shape_ops
func @erase_stray_shape_ops(%arg0: index) {
// CHECK-NOT: tcp.shape_observe_error

@@ -24,7 +29,6 @@ func @erase_stray_shape_ops(%arg0: index) {
}

// -----

func @cannot_erase_stray_shape_ops() -> !shape.shape {
// expected-error @+1 {{could not be eliminated}}
%0 = shape.from_extents

@@ -32,14 +36,15 @@ func @cannot_erase_stray_shape_ops() -> !shape.shape {
}

// -----

+// TODO: Remove this as it is now just testing shape and std ops.
// CHECK-LABEL: func @const_shape
func @const_shape() -> index {
// CHECK-NOT: shape.const_shape
-%0 = shape.const_shape []
-%1 = shape.const_shape [7]
-%2 = tcp.get_extent %1, 0
+%0 = shape.const_shape [] : tensor<?xindex>
+%1 = shape.const_shape [7] : tensor<?xindex>
+%2 = constant 0 : index
+%3 = shape.get_extent %1, %2 : tensor<?xindex>, index -> index
// CHECK: %[[C7:.*]] = constant 7 : index
// CHECK: return %[[C7]]
-return %2 : index
+return %3 : index
}
@@ -18,17 +18,18 @@ func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[VAL_1:.*]] = constant 0 : i32
// CHECK: %[[VAL_2:.*]] = npcomprt.get_extent %[[VAL_0]], %[[VAL_1]]
// CHECK: %[[VAL_3:.*]] = shape.from_extents %[[VAL_2]]
-// CHECK: %[[VAL_4:.*]] = tcp.alloc_memref %[[VAL_3]] : memref<?xf32>
-// CHECK: %[[VAL_5:.*]] = npcomprt.to_memref %[[VAL_0]] : memref<*xf32>
-// CHECK: %[[VAL_6:.*]] = memref_cast %[[VAL_5]] : memref<*xf32> to memref<?xf32>
-// CHECK: linalg.copy(%[[VAL_6]], %[[VAL_4]]) : memref<?xf32>, memref<?xf32>
-// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_4]] : memref<?xf32> to memref<*xf32>
-// CHECK: %[[VAL_8:.*]] = npcomprt.from_memref %[[VAL_7]] : memref<*xf32>
-// CHECK: return %[[VAL_8]] : !npcomprt.tensor
+// CHECK: %[[VAL_4:.*]] = shape.to_extent_tensor %[[VAL_3]]
+// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
+// CHECK: %[[VAL_6:.*]] = npcomprt.to_memref %[[VAL_0]] : memref<*xf32>
+// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_6]] : memref<*xf32> to memref<?xf32>
+// CHECK: linalg.copy(%[[VAL_7]], %[[VAL_5]]) : memref<?xf32>, memref<?xf32>
+// CHECK: %[[VAL_8:.*]] = memref_cast %[[VAL_5]] : memref<?xf32> to memref<*xf32>
+// CHECK: %[[VAL_9:.*]] = npcomprt.from_memref %[[VAL_8]] : memref<*xf32>
+// CHECK: return %[[VAL_9]] : !npcomprt.tensor
// CHECK: }

func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-%shape = shape.shape_of %arg0 : tensor<?xf32>
+%shape = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
%memref = tcp.alloc_memref %shape : memref<?xf32>
tensor_store %arg0, %memref : memref<?xf32>
%ret = tensor_load %memref : memref<?xf32>

@@ -40,7 +41,7 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
tcp.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
func @gets_global() -> tensor<10xf32> {
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
@@ -1,20 +1,20 @@
// RUN: npcomp-opt -resolve-shape-of-ops <%s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @basic
-func @basic(%arg0: !shape.shape) -> !shape.shape {
+func @basic(%arg0: tensor<?xindex>) -> tensor<?xindex> {
%memref = tcp.alloc_memref %arg0 : memref<?xf32>
%tensor = tensor_load %memref : memref<?xf32>
-%shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> !shape.shape
+%shape = "shape.shape_of"(%tensor) : (tensor<?xf32>) -> tensor<?xindex>
// CHECK: return %arg0
-return %shape : !shape.shape
+return %shape : tensor<?xindex>
}

// -----

// CHECK-LABEL: func @arg_unresolved_ok
-func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> !shape.shape {
-%0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> !shape.shape
-return %0 : !shape.shape
+func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> tensor<?xindex> {
+%0 = "shape.shape_of"(%arg0): (tensor<?xf32>) -> tensor<?xindex>
+return %0 : tensor<?xindex>
}

// -----

@@ -22,9 +22,9 @@ func @arg_unresolved_ok(%arg0: tensor<?xf32>) -> !shape.shape {
// CHECK-LABEL: func @TODO_bb_arg_unresolved_not_ok
// TODO: This should emit a diagnostic, but doesn't. Why?
// addDynamicallyLegalOp isn't working as I expect.
-func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> !shape.shape {
+func @TODO_bb_arg_unresolved_not_ok(%arg0: i1, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xindex> {
cond_br %arg0, ^bb1(%arg1: tensor<?xf32>), ^bb1(%arg2: tensor<?xf32>)
^bb1(%bbarg: tensor<?xf32>):
-%0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> !shape.shape
-return %0 : !shape.shape
+%0 = "shape.shape_of"(%bbarg): (tensor<?xf32>) -> tensor<?xindex>
+return %0 : tensor<?xindex>
}
@@ -3,7 +3,7 @@
// CHECK-LABEL: func @basic
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

-%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> tensor<?xindex>

// CHECK: %[[SRCMEMREF:.+]] = tcp.alloc_memref
%src_memref = tcp.alloc_memref %shape : memref<?xf32>
@@ -2,10 +2,10 @@
// RUN: -invoke scalar \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s

// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tcf.add"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}