//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

using namespace mlir;
using namespace mlir::NPCOMP;

namespace {
class LowerCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // A shape.cstr_* op should be the result of lowering a !shape.shape; it
    // should not itself ever consume or produce a !shape.shape.
    //
    // There is no way to "sink" a !shape.shape type, because one cannot
    // inspect whether it is an error. The only way to use it safely is to
    // lower the op that produced the value to a set of constraints and then
    // use the witness to guard a shape.assuming.
    //
    // Consider, for example, what we do when lowering TCF to TCP: we need to
    // do a shape calculation for the broadcasting, but we create the
    // shape.cstr_broadcastable and use its witness to guard a
    // `shape.assuming { ... shape.broadcast ... }`. There's never any need to
    // create a !shape.shape.
    //
    // The use of !shape.shape should be restricted to contexts like
    // declarations of shape transfer functions, with automatic utilities to
    // lower !shape.shape types to corresponding constraints + shape.assuming
    // + tensors. In this (npcomp e2e) lowering flow, we don't have any such
    // "declarative shape transfer functions" or utilities to expand them to
    // constraints, so !shape.shape should never exist in our IR.
    //
    // Historically, we used the !shape.shape type for everything, and
    // shape.to_extent_tensor would abort in case of an error. But that's not
    // a useful semantics for lowering, since the error is defined to happen
    // as part of the shape.to_extent_tensor op, which requires materializing
    // an "is error" bit in the IR and carrying it around everywhere that the
    // original !shape.shape value was being used. In practice, nobody
    // respects that, which opens us up to miscompilations. That is, the
    // lowering strategy is either "don't emit errors at all" or "emit errors
    // as part of lowering e.g. the shape.broadcast op itself" (which
    // technically puts the errors in some random location in the IR that is
    // not the shape.to_extent_tensor op). E.g. the following code would
    // miscompile with either of the ways that these ops get lowered in
    // practice:
    // ```
    // %shape = shape.broadcast %lhs, %rhs : !shape.shape
    // if %cond:
    //     shape.to_extent_tensor(%shape)
    // ```
    // It's not possible to correctly compile this code without significant
    // contortions (such as carrying an "is error" bit). And to boot, we
    // shouldn't be getting into that situation in the first place! But the
    // `shape.to_extent_tensor : !shape.shape -> tensor<?xindex>` abstraction
    // opens up that possibility.
    //
    // shape.to_extent_tensor should not really be a thing, since it creates
    // these ill-defined situations about where errors are observed. A
    // !shape.shape type should only exist (for this compilation flow) as part
    // of a utility, something like "I want to do this shape calculation on
    // !shape.shape type; create IR that uses tensor<?xindex> and witnesses to
    // implement it, on the assumption that the error can be observed anywhere
    // inside the shape calculation".
    //
    // The !shape.shape type would still be useful for lowerings that actually
    // result in a runtime type that carries an "is error" bit inside it,
    // though it is TBD if such use cases arise.
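
    // To make the rewrite below easier to follow, here is a rough sketch
    // (illustrative only; the exact op spellings and syntax depend on the
    // MLIR version in use, and the value names are invented) of the IR this
    // pattern emits for a shape.cstr_broadcastable on two extent tensors:
    // ```
    // %lhsRank = dim %lhs, %c0 : tensor<?xindex>
    // %rhsRank = dim %rhs, %c0 : tensor<?xindex>
    // // scf.if selects the lesser/greater rank and operand...
    // %rankDiff = subi %greaterRank, %lesserRank : index
    // scf.for %iv = %rankDiff to %greaterRank step %c1 {
    //   // ...extract the two aligned extents and compute
    //   // "extents are equal OR either extent is 1"...
    //   assert %broadcastIsValid, "invalid broadcast"
    // }
    // %witness = shape.const_witness true
    // ```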
    if (op.getType().isa<shape::ShapeType>() ||
        op.lhs().getType().isa<shape::ShapeType>() ||
        op.rhs().getType().isa<shape::ShapeType>()) {
      return op.emitError() << "Error shapes should not exist at this point";
    }

    auto loc = op.getLoc();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);

    // Find the lesser and greater rank and corresponding extent tensor.
    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
    Value lhsSmaller =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
    Type indexTy = rewriter.getIndexType();
    Type extentTensorTy = op.lhs().getType();
    auto ifOp = rewriter.create<scf::IfOp>(
        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
        lhsSmaller,
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(
              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
        },
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(
              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
        });
    Value lesserRank = ifOp.getResult(0);
    Value lesserRankOperand = ifOp.getResult(1);
    Value greaterRank = ifOp.getResult(2);
    Value greaterRankOperand = ifOp.getResult(3);

    Value rankDiff =
        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

    // Compare the shapes extent by extent, and emit errors for
    // non-broadcast-compatible shapes.
    // Two extents are broadcast-compatible if
    // 1. they are both equal, or
    // 2. at least one of them is 1.
    rewriter.create<scf::ForOp>(
        loc, rankDiff, greaterRank, one, llvm::None,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
              loc, greaterRankOperand, ValueRange{iv});
          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
              loc, lesserRankOperand, ValueRange{ivShifted});

          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
          Value extentsAgree =
              b.create<CmpIOp>(loc, CmpIPredicate::eq,
                               greaterRankOperandExtent,
                               lesserRankOperandExtent);
          auto broadcastIsValid =
              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                            lesserRankOperandExtentIsOne));
          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
          b.create<scf::YieldOp>(loc);
        });

    // Now that we have emitted all the assertions, the witness is trivially
    // satisfied.
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

namespace {
class LowerCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    // Lower the constraint directly to a runtime assertion on its predicate,
    // then replace the witness with a trivially-true constant witness.
    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace
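
// For illustration (a sketch only; the exact op syntax depends on the MLIR
// version, and the message string is invented), a shape.cstr_require such as
//   %w = shape.cstr_require %pred, "some error message"
// lowers under the pattern above to
//   assert %pred, "some error message"
//   %w = shape.const_witness true
// The now-constant witness is what later lets shape.assuming regions be
// inlined away by the canonicalization patterns registered in the pass below.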
namespace {
// This pass eliminates shape constraints from the program.
//
// After this pass finishes, there are no !shape.witness types in the program,
// no shape.assuming ops, and no shape.cstr_* ops.
//
// TODO: This should move to upstream ShapeToStandard conversions.
class LowerShapeConstraints
    : public LowerShapeConstraintsBase<LowerShapeConstraints> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect>();
  }
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    patterns.insert<LowerCstrBroadcastableOp>(context);
    patterns.insert<LowerCstrRequireOp>(context);
    // Add in the canonicalization patterns for shape.assuming so that it gets
    // inlined when its witness becomes a true constant witness.
    shape::AssumingOp::getCanonicalizationPatterns(patterns, context);

    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
      return signalPassFailure();

    // TODO: Check that there are no remaining !shape.witness types,
    // shape.assuming ops, shape.cstr_* ops, etc.
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerShapeConstraintsPass() {
  return std::make_unique<LowerShapeConstraints>();
}
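
// A minimal usage sketch, assuming a standard MLIR PassManager setup (the
// pipeline construction below is hypothetical and not part of this file):
//
//   PassManager pm(context);
//   pm.addNestedPass<FuncOp>(createLowerShapeConstraintsPass());
//   if (failed(pm.run(module)))
//     ... // handle pass failure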