//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is the base file for our "end-to-end" npcomp lowering pipeline. // At the moment, the first "end" is TCF ops and the second "end" is `llvm` // dialect suitable for jitting. // // This is still work-in-progress and not even working end-to-end for the // most trivial examples, see TODO's in createE2ELoweringPipeline for the // status. // // As a pragmatic matter, I generally tend to drop random passes and stuff // inside this top-level file and then shard it out to separate files once // a clear organizing principle arises (to avoid premature organizing). // // Once we have end-to-end functionality working, we will throw // increasingly complex programs and augment this pass pipeline, likely // introducing better structure and more clear principles. // // I wish I had a clear view of how this pipeline should perfectly layer // ahead of time, but unfortunately I don't since it crosses half a dozen // abstraction levels / dialects, some of which have no precedent that I'm // aware of (dynamic-shape-aware, error-aware TCF -> TCP) or very little // (tensor -> memref/buffer with dynamic shapes, shape -> SSA values for // ranked shape extents). // // Right now there's lots of stuff in this pipeline that is limited to // special cases where I have an idea of how to elaborate it to the general // case. The priority is getting and end-to-end flow working that we can // grow out organically to a curriculum of more complex cases, elaborating // on the design principles and layering as necessitated by the curriculum. // // This should be fun :) // //===----------------------------------------------------------------------===// #include "npcomp/E2E/E2E.h" #include "PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h" #include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h" #include "npcomp/Dialect/TCP/IR/TCPDialect.h" #include "npcomp/Dialect/TCP/IR/TCPOps.h" using namespace mlir; using namespace mlir::NPCOMP; //===----------------------------------------------------------------------===// // ResolveShapeOfOps //===----------------------------------------------------------------------===// class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::ShapeOfOp op, PatternRewriter &rewriter) const override { if (auto tensorLoad = llvm::dyn_cast_or_null( op.getOperand().getDefiningOp())) { if (auto allocMemRef = llvm::dyn_cast_or_null( tensorLoad.getOperand().getDefiningOp())) { rewriter.replaceOp(op, allocMemRef.getOperand()); return success(); } } return failure(); } }; class ResolveShapeOfOps : public ResolveShapeOfOpsBase { void runOnOperation() { auto func = getOperation(); auto *context = &getContext(); OwningRewritePatternList patterns; patterns.insert(context); ConversionTarget target(*context); //target.addIllegalOp(); target.addDynamicallyLegalOp( [](shape::ShapeOfOp shapeOf) { // Only shape.shape_of on arguments to the entry block are legal at // this point. They are assumed to be resolved eventually via // the lowering of the tensor argument to some ABI that has the // relevant information available. But this is ABI dependent. // TODO: Convince myself that we never need to deal with general // block operands, or implement general handling of block // operands (need to add new bb operands of !shape.shape type). if (auto blockArg = shapeOf.getOperand().dyn_cast()) { Block *block = blockArg.getOwner(); if (&block->getParent()->front() == block) { return true; } } return false; }); if (failed(applyPartialConversion(func, target, patterns))) { return signalPassFailure(); } } }; std::unique_ptr> mlir::NPCOMP::createResolveShapeOfOpsPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // ResolveTensorLoadStoreOps //===----------------------------------------------------------------------===// class ReplaceTensorStoreWithCopyPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorStoreOp op, PatternRewriter &rewriter) const override { auto tensorLoad = llvm::dyn_cast_or_null(op.tensor().getDefiningOp()); if (!tensorLoad) return rewriter.notifyMatchFailure(op, "not fed by tensor_load op"); rewriter.replaceOpWithNewOp(op, tensorLoad.memref(), op.memref()); return success(); } }; class EraseUnusedTensorLoadOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorLoadOp op, PatternRewriter &rewriter) const override { if (!op.use_empty()) return rewriter.notifyMatchFailure(op, "has uses"); rewriter.eraseOp(op); return success(); } }; class ResolveTensorLoadStoreOps : public ResolveTensorLoadStoreOpsBase { void runOnOperation() { auto func = getOperation(); auto *context = &getContext(); OwningRewritePatternList patterns; patterns.insert(context); patterns.insert(context); ConversionTarget target(*context); target.addLegalDialect(); target.addDynamicallyLegalOp([](TensorLoadOp op) { for (auto user : op.getResult().getUsers()) if (!isa(user)) return false; return true; }); target.addDynamicallyLegalOp( [](TensorStoreOp op) { return op.tensor().isa(); }); if (failed(applyPartialConversion(func, target, patterns))) { return signalPassFailure(); } } }; std::unique_ptr> mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // LowerLinalgLoopDimOps //===----------------------------------------------------------------------===// class LowerLinalgLoopDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DimOp op, PatternRewriter &rewriter) const override { auto allocMemRef = op.getOperand().getDefiningOp(); if (!allocMemRef) return rewriter.notifyMatchFailure(op, "could not find alloc_memref"); rewriter.replaceOpWithNewOp(op, allocMemRef.shape(), op.index()); return success(); } }; class LowerLinalgLoopDimOps : public LowerLinalgLoopDimOpsBase { void runOnOperation() { auto func = getOperation(); auto *context = &getContext(); OwningRewritePatternList patterns; patterns.insert(context); ConversionTarget target(*context); target.addIllegalOp(); target.addLegalOp(); if (failed(applyPartialConversion(func, target, patterns))) { return signalPassFailure(); } } }; std::unique_ptr> mlir::NPCOMP::createLowerLinalgLoopDimOpsPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // createE2ELoweringPipeline //===----------------------------------------------------------------------===// void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) { // Input IR is TCF ops. // Convert to TCP. pm.addPass(createConvertTCFToTCPPass()); // TODO: Do tcp.island coarsening here. // -------------------------------------------------------------------------- // Tensor to buffer (memref) conversion. // -------------------------------------------------------------------------- // Convert tcp ops to Linalg where possible, as we want generic linalg // tensor->memref to do most of the mechanical work of rewriting ops in // terms of tensors to ops in terms of memrefs (since it is easy on that // representation). pm.addPass(createConvertTCPToLinalgPass()); // Lower to hybrid tensor/memref // // The hybrid tensor/memref representation guarantees: // - every use of a tensor is a tensor_store op writing it into a memref // - every def of a tensor is a tensor_load op loading out of some memref. // - every memref is allocated by a `tcp.alloc_memref(%shape)` op. // - every memref is only ever writen once, and never mutated // // Exceptions: "boundaries" such as function arguments and island // live-outs. // // Or, another way to say this: the hybrid tensor/memref representation // doesn't attempt to eliminate the original tensors from the program, // but rather locally expands operations on tensors to be small subgraphs // with tensor_load/tensor_store at the boundaries, leaving enough // invariants that we can clean it up later. // // The core invariants that are needed for this step are that the // tensor-level ops we receive as input have a way of calculating the // sizes for their outputs. This is equivalent to saying that // `shape.shape_of` on the result of an op must be calculatable in terms // of the shapes of the inputs to the op. createLowerToHybridTensorMemRefPipeline(pm); // At this point, the invariants of the hybrid tensor/memref // representation allow us to resolve `shape.shape_of` ops to shape // computations earlier in the program. Specifically, every // `shape.shape_of` can be resolved to the shape argument to the // corresponding `tcp.alloc_memref` op of the tensor_load that produced // that tensor. pm.addPass(createResolveShapeOfOpsPass()); // Now, we use the hybrid tensor/memref invariants to replace the // tensor_store ops with memref copy operations and erase the // tensor_load/tensor_store ops. pm.addPass(createResolveTensorLoadStoreOpsPass()); // At this point, the IR is in a form where the interiors of islands // don't have tensor ops (except tensor_store's of arguments and // tensor_load's of returns). // // This is a reasonable representation for doing buffer assignment. // // TODO: Might want a different kind of island to better represent this. // This island op would explicitly capture all tensors as inputs, and it // would establish a more formalized ABI with the interior of the body // region (much like IREE does with dispatch regions). For now, we are // planning on just inlining the islands, so there is little value in // doing this, but we should look at the layering aspects here later. // At this point, we have loose shape calculations floating around, so // it's a good time to do some general cleanups. pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); // -------------------------------------------------------------------------- // Lowering down to LLVM // -------------------------------------------------------------------------- // Now, we begin the process of lowering to LLVM's level of abstraction // (after which LLVM will take over lowering to machine code). // Lower linalg ops to loops. // TODO: Do some linalg optimizations like tiling here. pm.addPass(createConvertLinalgToLoopsPass()); // Lowering linalg to loops introduces `dim` ops. Here we look through // use-def chains to find `tcp.alloc_memref` ops that we can get a shape // out of. // Currently, this is trivial, but after more aggressive buffer // allocation optimizations or linalg tiling this step will need to look // through slices/views and stuff. // TODO: It seems that "dim on memrefs" is being resolved in a // fundamentally different way from "dim on tensors" is earlier in the // pipeline. Investigate. // We could somewhat unify them by having enough folding patterns for // `shape.shape_of`. Above, we used the pattern // "shape_of(tensor_load(alloc_memref(%shape))) -> %shape". Here we are // doing `shape_of(alloc_memref(%shape)) -> %shape". It seems // dangerous to just have a pile of these patterns and hope that one of // them resolves things at any given point. So what we do is to use a // very narrowly focused set of patterns that exploit just the invariants // at each point. pm.addPass(createLowerLinalgLoopDimOpsPass()); // Lower shapes to SSA values. // This replaces all tcf::GetExtentOp's with explicit SSA computations // for the scalar extent. This requires shapes which are ranked. Any // unranked shapes will need to be handled by a runtime shape type, // though we don't currently support that. // // At this point, in the case of programs with only ranked shapes, all // !shape.shape types will be gone. // TODO: Better demarcate the invariants here, such as having a verifier // pass that checks no !shape.shape types left. pm.addPass(createLowerRankedShapesPass()); // TODO: // Convert all of it to LLVM? }