//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "npcomp/E2E/E2E.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;
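// Allocates output buffers for `op`, which must be nested directly inside a
// tcp.shaped_results op whose yielded values are exactly `op`'s results; the
// wrapper's result shapes drive the tcp.alloc_memref allocations below.
// Illustrative sketch of the expected form (syntax approximate, not verbatim
// npcomp IR):
//
//   %result = tcp.shaped_results %shape {
//     %0 = "some.op"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//     tcp.yield %0 : tensor<?xf32>
//   } : tensor<?xf32>
//
// Any other nesting is rejected with a match failure.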
static FailureOr<SmallVector<Value, 6>>
allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
                Location loc,
                SmallVectorImpl<Value> *resultShapesOut = nullptr) {
  // TODO: This is really fragile. Can we have a better story?
  auto shapedResults = dyn_cast<tcp::ShapedResultsOp>(op->getParentOp());
  if (!shapedResults)
    return rewriter.notifyMatchFailure(op, "parent not tcp.shaped_results");
  if (op->getResults() !=
      shapedResults.getBody()->getTerminator()->getOperands())
    return rewriter.notifyMatchFailure(
        op, "only limited forms of tcp.shaped_results allowed");
  auto resultShapes = shapedResults.resultShapes();
  SmallVector<Value, 6> results;
  for (auto t : llvm::zip(op->getResults(), resultShapes)) {
    auto result = std::get<0>(t);
    auto resultShape = std::get<1>(t);
    auto tensorType = result.getType().cast<RankedTensorType>();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    auto memref =
        rewriter.create<tcp::AllocMemRefOp>(loc, memrefType, resultShape);
    results.push_back(memref);
  }
  if (resultShapesOut)
    resultShapesOut->append(resultShapes.begin(), resultShapes.end());
  return results;
}
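// Sketch of the lowering performed by the pattern below (illustrative only;
// op spellings and syntax are approximate). Broadcasting tensor<?xf32> to
// tensor<?x?xf32> becomes, roughly:
//
//   %result = tcp.alloc_memref(%shape) : memref<?x?xf32>
//   scf.for %i = %c0 to %extent0 step %c1 {
//     scf.for %j = %c0 to %extent1 step %c1 {
//       // Dimensions being broadcast read from index 0 of the input.
//       %src_j = select %dim_requires_broadcast, %c0, %j : index
//       %v = load %input[%src_j] : memref<?xf32>
//       store %v, %result[%i, %j] : memref<?x?xf32>
//     }
//   }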
namespace {
// TODO: Lower to a "buffer version" of tcp::BroadcastTo instead of directly to
// loops.
class LowerBroadcastToToLoopsPattern
    : public OpConversionPattern<tcp::BroadcastToOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::BroadcastToOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().cast<RankedTensorType>();
    auto inputType = op.operand().getType().cast<RankedTensorType>();
    SmallVector<Value, 6> resultShapes;
    auto resultsOrFailure =
        allocateResults(op, rewriter, op.getLoc(), &resultShapes);
    if (failed(resultsOrFailure))
      return failure();
    Value resultMemref = (*resultsOrFailure)[0];
    auto resultShape = resultShapes[0];
    Value inputMemref = operands[0];

    SmallVector<Value, 6> outputExtents;
    for (int i = 0, e = resultType.getRank(); i < e; i++) {
      Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
      Value outputExtent = rewriter.create<shape::GetExtentOp>(
          op.getLoc(), rewriter.getIndexType(), resultShape, dimIndex);
      outputExtents.push_back(outputExtent);
    }
    int rankDiff = resultType.getRank() - inputType.getRank();
    SmallVector<Value, 6> inputDimRequiresBroadcasting;
    for (int i = 0, e = inputType.getRank(); i < e; i++) {
      // Calculate the relevant extents.
      Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
      inputDimRequiresBroadcasting.push_back(
          rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
                                  outputExtents[rankDiff + i]));
    }

    {
      OpBuilder::InsertionGuard guard(rewriter);
      Value c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
      Value c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);

      SmallVector<Value, 6> inductionVariables;
      // Create the (perfectly nested) loops.
      // Loop invariant: At the start of iteration `i`, the rewriter insertion
      // point is inside `i` nested loops.
      for (int i = 0, e = resultType.getRank(); i < e; i++) {
        auto loop = rewriter.create<scf::ForOp>(
            op.getLoc(), c0, outputExtents[i], c1, ValueRange({}));
        Block *body = loop.getBody();
        inductionVariables.push_back(body->getArgument(0));
        // Leave the insertion point at the beginning of the body.
        rewriter.setInsertionPointToStart(body);
      }

      // Create the inner loop body.
      // When reading from the input, clamp any indices for dimensions that are
      // being broadcast.
      SmallVector<Value, 6> inputIndices;
      for (int i = 0, e = inputType.getRank(); i < e; i++) {
        auto c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
        auto select = rewriter.create<SelectOp>(
            op.getLoc(), inputDimRequiresBroadcasting[i], c0,
            inductionVariables[rankDiff + i]);
        inputIndices.push_back(select);
      }
      Value load =
          rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
      rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
                               inductionVariables);
    }

    rewriter.replaceOp(op, resultMemref);
    return success();
  }
};
} // namespace
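// The pattern below converts a linalg.generic on tensors into a linalg.generic
// on memrefs: the converted input memrefs are kept, freshly allocated output
// memrefs are appended to the operand list, and one block argument per output
// is added to the cloned body. Rough before/after sketch (illustrative only;
// attributes and body elided, syntax approximate):
//
//   %0 = linalg.generic {...} %t { ^bb0(%a: f32): ... } : tensor<?xf32>
//     ==>
//   %buf = tcp.alloc_memref(%shape) : memref<?xf32>
//   linalg.generic {...} %m, %buf { ^bb0(%a: f32, %out: f32): ... }
//       : memref<?xf32>, memref<?xf32>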
namespace {
class LowerLinalgGenericTensorToMemRef
    : public OpConversionPattern<linalg::GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {

    // TODO: Replace this with more generic code operating on named
    // structured ops too.

    // These checks mirror those in BypassShapes.
    if (!llvm::all_of(op.getOperandTypes(), [](Type type) {
          return type.isa<RankedTensorType>();
        })) {
      return rewriter.notifyMatchFailure(op, "all operands must be tensors");
    }
    if (!llvm::all_of(op.getResultTypes(), [](Type type) {
          return type.isa<RankedTensorType>();
        })) {
      return rewriter.notifyMatchFailure(op, "all results must be tensors");
    }
    if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
          return map.cast<AffineMapAttr>().getValue().isIdentity();
        })) {
      return rewriter.notifyMatchFailure(
          op, "all indexing maps must be identity maps");
    }
    if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
          return str.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })) {
      return rewriter.notifyMatchFailure(
          op, "all iterator types must be 'parallel'");
    }

    SmallVector<Value, 6> memrefs(operands.begin(), operands.end());
    auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
    if (failed(resultsOrFailure))
      return failure();
    auto results = *resultsOrFailure;
    memrefs.append(results.begin(), results.end());
    auto newGeneric = rewriter.create<linalg::GenericOp>(
        op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
    newGeneric.region().getBlocks().clear();
    BlockAndValueMapping mapper;
    op.region().cloneInto(&newGeneric.region(), mapper);
    for (auto memref : results) {
      newGeneric.region().front().addArgument(
          memref.getType().cast<MemRefType>().getElementType());
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
} // namespace

namespace {
class LowerTcpMatmulOp : public OpConversionPattern<tcp::MatmulOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::MatmulOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
    if (failed(resultsOrFailure))
      return failure();
    auto results = *resultsOrFailure;
    rewriter.create<linalg::MatmulOp>(op.getLoc(), operands, results);
    rewriter.replaceOp(op, results);
    return success();
  }
};
} // namespace

namespace {
// TODO: Linalg and shape don't implement the inliner interface, which blocks
// us from using mlir::inlineRegion. Locally override it here.
class LocallyOverrideLegalityInlinerInterface : public InlinerInterface {
public:
  using InlinerInterface::InlinerInterface;
  bool isLegalToInline(Operation *op, Region *dest,
                       BlockAndValueMapping &valueMapping) const final {
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src,
                       BlockAndValueMapping &valueMapping) const final {
    return true;
  }
};
} // namespace

namespace {
// This pass is responsible for lowering regions wrapped by
// tcp.shaped_results (which operate on tensors) to memrefs.
// This includes any ops potentially contained within them.
// This is somewhat analogous to IREE's backend compilation of a single
// dispatch region, except that for now, we only allow a single op in the
// tcp.shaped_results, and we don't have any notion of "backend" layered at
// all. Nor is it clear if we really want any of that here.
//
// The tcp.shaped_results ops provide precisely the information needed to
// allocate output buffers when converting to memref.
// For now, this process eliminates the original tcp.shaped_results op since we
// don't have any host/device distinction or other structure that would require
// retaining that sort of IR structure.
//
// TODO: Do "shape_of" resolution while still on tensors.
// Here we spew out tons of shape_of ops and rely on dim ops on descriptors to
// make it work. The key difference is that we need tcp.shaped_results (or its
// successor / something it gets lowered to) to not be IsolatedFromAbove, and
// to explicitly capture all input tensors along with their shapes. That allows
// shape_of ops on inputs to be trivially resolved. Unfortunately, this opens
// up the whole "dispatch region formation" can of worms that exists in IREE --
// once you have multiple ops inside a "dispatch region", you need to somehow
// lower them without allocating intermediate buffers.
//
// TODO: Don't hardcode the lowering for every op in this one pass.
class LowerShapedResultsToMemref
    : public LowerShapedResultsToMemrefBase<LowerShapedResultsToMemref> {
  void runOnOperation() {
    auto func = getOperation();
    auto *context = &getContext();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    typeConverter.addConversion([](RankedTensorType type) -> Type {
      return MemRefType::get(type.getShape(), type.getElementType());
    });
    typeConverter.addSourceMaterialization([](OpBuilder &builder,
                                              RankedTensorType type,
                                              ValueRange inputs, Location loc) {
      assert(inputs.size() == 1);
      assert(inputs[0].getType().isa<MemRefType>());
      return (Value)builder.create<tcp::MemrefToTensorOp>(loc, type, inputs[0]);
    });
    typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                              MemRefType type,
                                              ValueRange inputs, Location loc) {
      assert(inputs.size() == 1);
      assert(inputs[0].getType().isa<RankedTensorType>());
      return (Value)builder.create<tcp::TensorToMemrefOp>(loc, type, inputs[0]);
    });

    OwningRewritePatternList patterns;

    ConversionTarget target(*context);

    // The shaped results ops themselves. They have to be legal since we delete
    // them later after the conversion process.
    target.addLegalOp<tcp::ShapedResultsOp>();
    target.addLegalOp<tcp::YieldOp>();
    // All lowering to buffers involves tcp.alloc_memref ops.
    target.addLegalOp<tcp::AllocMemRefOp>();
    // The casting ops are introduced by the type converter, so we should mark
    // them legal.
    target.addLegalOp<tcp::MemrefToTensorOp>();
    target.addLegalOp<tcp::TensorToMemrefOp>();

    patterns.insert<LowerLinalgGenericTensorToMemRef>(typeConverter, context);
    target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
      if (llvm::any_of(op.getOperandTypes(), [](Type type) {
            return type.isa<RankedTensorType>();
          })) {
        return false;
      }
      if (llvm::any_of(op.getResultTypes(), [](Type type) {
            return type.isa<RankedTensorType>();
          })) {
        return false;
      }
      return true;
    });

    patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
    target.addIllegalOp<tcp::BroadcastToOp>();
    patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
    target.addIllegalOp<tcp::MatmulOp>();

    target.addLegalDialect<linalg::LinalgDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalDialect<scf::SCFDialect>();
    target.addLegalOp<shape::GetExtentOp>();

    SmallVector<Operation *, 6> shapedResultsOps;
    func.walk(
        [&](tcp::ShapedResultsOp op) { shapedResultsOps.push_back(op); });

    if (failed(applyFullConversion(shapedResultsOps, target, patterns)))
      return signalPassFailure();

    // Now inline the tcp.shaped_results ops.
    // This can't be done as part of the conversion since conversion visits
    // ops in preorder, and we need the tcp.shaped_results ops to be present
    // so that inner ops can get their shape.
    LocallyOverrideLegalityInlinerInterface interface(context);
    for (Operation *shapedResultsOp : shapedResultsOps) {
      auto op = cast<tcp::ShapedResultsOp>(shapedResultsOp);
      if (failed(inlineRegion(interface, &op.body(), op, ValueRange({}),
                              op.getResults(), /*inlineLoc=*/llvm::None,
                              /*shouldCloneInlinedRegion=*/false))) {
        op.emitError() << "could not inline body";
        return signalPassFailure();
      }
      op.erase();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerShapedResultsToMemrefPass() {
  return std::make_unique<LowerShapedResultsToMemref>();
}