//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "npcomp/E2E/E2E.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h" #include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h" #include "npcomp/Dialect/TCP/IR/TCPOps.h" using namespace mlir; using namespace mlir::NPCOMP; namespace { class LowerConstShapeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(shape::ConstShapeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto extents = llvm::to_vector<6>(llvm::map_range( op.shape().getValues(), [&](int64_t extent) -> Value { return rewriter.create(op.getLoc(), extent); })); rewriter.replaceOpWithNewOp( op, rewriter.getType(), extents); return success(); } }; } // namespace namespace { // Given an operand that is either a Shape or Extent Tensor, returns an // Extent Tensor or nullptr if this cannot be locally determined. // The return value, if !nullptr, will be a 1D RankedTensorType (with possibly // unknown element). Value findExtentsFromShape(Value operand, bool requireKnownRank) { if (auto tensorType = operand.getType().dyn_cast()) { if (tensorType.getRank() == 1 && (!requireKnownRank || tensorType.hasStaticShape())) { return operand; } } return nullptr; } class LowerShapeBroadcastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(shape::BroadcastOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { shape::BroadcastOp::Adaptor adaptor(operands); // When the ranks are statically known, generate non-branchy code. // TODO: Generate rank-generic code. auto lhsExtents = findExtentsFromShape(adaptor.lhs(), true); auto rhsExtents = findExtentsFromShape(adaptor.rhs(), true); if (!lhsExtents || !rhsExtents) return rewriter.notifyMatchFailure(op, "dynamic extents not supported"); // Establish invariant that rank(lhs) >= rank(rhs). auto lhsSize = lhsExtents.getType().cast().getDimSize(0); auto rhsSize = rhsExtents.getType().cast().getDimSize(0); if (lhsSize < rhsSize) { std::swap(lhsExtents, rhsExtents); std::swap(lhsSize, rhsSize); } auto rankDiscrepancy = lhsSize - rhsSize; // Helper that creates IR // ``` // abort_if(extent != resultExtent && extent != 1) // ``` // This is the numpy broadcasting legality check. auto createAbortIfIllegalBroadcastExtent = [&](Value extent, Value resultExtent) { auto c1 = rewriter.create(op.getLoc(), 1); auto extentNeMax = rewriter.create(op.getLoc(), CmpIPredicate::ne, extent, resultExtent); auto extentNeOne = rewriter.create(op.getLoc(), CmpIPredicate::ne, extent, c1); auto bothTrue = rewriter.create(op.getLoc(), extentNeMax, extentNeOne); // TODO: Should there be a more generic error-handling dialect? // It seems a bit awkward to hardcode npcomprt here. rewriter.create(op.getLoc(), bothTrue); }; SmallVector resultExtents; for (int i = 0, e = lhsSize; i < e; i++) { auto lhsDim = rewriter.create(op.getLoc(), i); auto lhsExtent = rewriter.create( op.getLoc(), lhsExtents, ValueRange{lhsDim}); if (i < rankDiscrepancy) { // Padded extent. resultExtents.push_back(lhsExtent); continue; } // Non-padded extent. auto rhsDim = rewriter.create(op.getLoc(), i - rankDiscrepancy); auto rhsExtent = rewriter.create( op.getLoc(), rhsExtents, ValueRange{rhsDim}); auto ugt = rewriter.create(op.getLoc(), CmpIPredicate::ugt, lhsExtent, rhsExtent); auto resultExtent = rewriter.create(op.getLoc(), ugt, lhsExtent, rhsExtent); createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent); createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent); resultExtents.push_back(resultExtent); } // TODO: Remove the return type once ODS is fixed to do proper inference. rewriter.replaceOpWithNewOp( op, shape::ShapeType::get(rewriter.getContext()), resultExtents); return success(); } }; } // namespace namespace { class LowerShapeToExtentTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { shape::ToExtentTensorOpAdaptor adaptor(operands); if (adaptor.input().getType().isa()) { // Convert by matching to a producing FromExtentsOp. auto fromExtents = adaptor.input().getDefiningOp(); if (!fromExtents) { return rewriter.notifyMatchFailure(op, "not a from_extents op"); } rewriter.replaceOpWithNewOp(op, fromExtents.extents()); return success(); } // Assume that it is already an extent tensor. // TODO: Since these ops are all multi-type, there should be a utility // for switching on the allowable types instead of just assuming that it // is an extent tensor. rewriter.replaceOp(op, adaptor.input()); return success(); } }; class LowerShapeGetExtentOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(shape::GetExtentOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { shape::GetExtentOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.shape(), adaptor.dim()); return success(); } }; } // namespace namespace { // Now that we have lowered ranked shapes, which reifies the eager // error-handling code, the tcp::ShapeObserveErrorOp's are no longer // needed. class EraseShapeObserveErrorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; } // namespace // Basic invariant of this pass: // Every `shape.from_extents` op operating on an extent tensor // (`tensor`) is replaced by corresponding standard ops and folded // away (for the ranked case, it should be possible to eliminate these). // // We expect that previous passes have inserted a "root" set of // shape::FromExtentsOp's that allow this process to get started. // // This is similar to the approach that is used in IREE. It is basically a // combination of the ConvertShapeToShapex pass and the // "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern. // These patterns have to be "conversion patterns" since the `operands` argument // gives access to the post-conversion operands from earlier ops. // // This pass depends heavily on ranked shapes, since only ranked shapes can // be statically expanded to a fixed set of SSA extents. // // TODO: This approach doesn't naively work with control flow. // In the presence of non-cyclic control flow, we can just generalize the // `getDefiningOp()` calls into something that will // look through block arguments and rewrite "phi of shapes -> phi of extents". // In the presence of cyclic control flow, we need to somehow resolve the // ranks of use-def cycles ahead of time or optimistically assume that // backedges will match the rank of forward edges, and somehow be robust // when that assumption fails. // // TODO: Add in a fold of // `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore // the above invariant without relying on a subsequent canonicalization // step. namespace { class LowerRankedShapes : public LowerRankedShapesBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); OwningRewritePatternList patterns; patterns.insert(context); patterns.insert(context); patterns.insert(context); patterns.insert(context); patterns.insert(context); ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addLegalOp(); target.addIllegalOp(); target.addLegalOp(); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) { return signalPassFailure(); } // Erase some stray shape ops from the program. They can't be // deleted during conversion because they become unused only after // subsequent patterns bypass them. auto walkResult = func.walk([](Operation *op) { if (!isa(op)) return WalkResult::advance(); if (op->use_empty()) { op->erase(); } else { op->emitError("could not be eliminated"); return WalkResult::interrupt(); } return WalkResult::advance(); }); if (walkResult.wasInterrupted()) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createLowerRankedShapesPass() { return std::make_unique(); }