//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is the base file for npcomp's "reference backend". // // The input to this backend is a layer that consists of linalg-on-tensors // together with std scalar ops and control flow. // // The output of this backend is LLVM IR suitable for JITing. // // We expect that other backends will appear that have a similar kind of // interface. IREE already uses this layering. // //===----------------------------------------------------------------------===// #include "npcomp/RefBackend/RefBackend.h" #include "PassDetail.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "npcomp/Dialect/Refback/IR/RefbackOps.h" using namespace mlir; using namespace mlir::NPCOMP; //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// namespace { #define GEN_PASS_REGISTRATION #include "npcomp/RefBackend/Passes.h.inc" } // end namespace void mlir::NPCOMP::registerRefBackendPasses() { ::registerPasses(); mlir::PassPipelineRegistration( "refback-lowering-pipeline", "RefBackend lowering pipeline.", mlir::NPCOMP::createRefBackendLoweringPipeline); } //===----------------------------------------------------------------------===// // LowerAllocMemRefOps //===----------------------------------------------------------------------===// namespace { class LowerAllocMemRefOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(refback::AllocMemRefOp op, PatternRewriter &rewriter) const override { auto memrefType = op.getType().cast(); auto shape = op.getOperand(); // std.alloc only accepts the dynamic extents as operands, so only // collect those. SmallVector dynamicExtents; for (int i = 0, e = memrefType.getRank(); i < e; i++) { if (memrefType.isDynamicDim(i)) { auto ci = rewriter.create(op.getLoc(), i); auto extent = rewriter.create(op.getLoc(), shape, ValueRange({ci})); dynamicExtents.push_back(extent); } } rewriter.replaceOpWithNewOp(op, memrefType, dynamicExtents); return success(); } }; } // namespace namespace { class LowerAllocMemRefOps : public LowerAllocMemRefOpsBase { void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); ConversionTarget target(*context); target.addIllegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); if (failed(applyPartialConversion(func, target, std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::NPCOMP::createLowerAllocMemRefOpsPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // RestrictedCanonicalizer //===----------------------------------------------------------------------===// namespace { struct RestrictedCanonicalizer : public RestrictedCanonicalizerBase { void runOnOperation() override { auto *context = &getContext(); // Find the dialects from their names. DenseSet neededDialects; for (const std::string &dialectName : includedDialects) neededDialects.insert(dialectName); DenseSet dialectsToCanonicalize; for (Dialect *dialect : context->getLoadedDialects()) { if (neededDialects.count(dialect->getNamespace())) { dialectsToCanonicalize.insert(dialect); // Erase the dialect so that we can report an error below for any // dialect names that are not loaded. neededDialects.erase(dialect->getNamespace()); } } // Report a helpful error if a dialect is not found. auto missingDialects = llvm::to_vector<6>(neededDialects); if (!missingDialects.empty()) { llvm::sort(missingDialects); std::string buf; llvm::raw_string_ostream os(buf); llvm::interleaveComma(missingDialects, os); llvm::report_fatal_error("restricted-canonicalize: unknown dialects: " + os.str()); } // Collect all canonicalization patterns from ops in the included dialects. RewritePatternSet patterns(context); for (AbstractOperation *op : context->getRegisteredOperations()) if (dialectsToCanonicalize.count(&op->dialect)) op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; } // end anonymous namespace std::unique_ptr mlir::NPCOMP::createRestrictedCanonicalizerPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // createRefBackendLoweringPipeline //===----------------------------------------------------------------------===// void mlir::NPCOMP::createRefBackendLoweringPipeline( OpPassManager &pm, const RefBackendLoweringPipelineOptions &options) { // Convert all elementwise ops to linalg. // // Considering correctness, this lets us reuse the linalg bufferization, which // applies uniformly to all linalg structured ops. // // Also, converting to linalg herevopens up a lot of optimization // opportunities. pm.addNestedPass(createConvertElementwiseToLinalgPass()); if (options.optimize) { pm.addNestedPass(createLinalgElementwiseOpFusionPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } // Lower shape constraints before we enter tensor->memref conversion. // That is, we expand shape.cstr_* ops to eager error handling code. pm.addNestedPass(createConvertShapeConstraintsPass()); // Run shape canonicalizations. In particular, this erases shape.assuming, // now that we have converted shape constraints. // TODO: This is kind of ugly. Either we use pass options or a constructor // that takes C++ data structures. The former makes the pass usable on the // command line (including reproducers), the latter makes the pass more // convenient. std::unique_ptr shapeCanonicalizer = createRestrictedCanonicalizerPass(); if (failed(shapeCanonicalizer->initializeOptions("included-dialects=shape"))) llvm::report_fatal_error("couldn't initialize restricted-canonicalize"); pm.addPass(std::move(shapeCanonicalizer)); // Lower shape ops to std. pm.addPass(createConvertShapeToStandardPass()); // -------------------------------------------------------------------------- // Lower the `tensor` type to `memref`. // -------------------------------------------------------------------------- // We make a conscious effort here to do this as a sequence of separate passes // rather than a single mega dialect conversion pass. // // This means that intermediate steps have source/target materializations // (memref.tensor_load / memref.buffer_cast) in the IR. // Run tensor constant bufferization. // This pass has to run on a module op, and so does the final // FuncBufferizePass. But everything else can run in parallel on functions, // so we try to bracket the entire bufferization pipeline with the module // passes to allow maximum parallelism. pm.addPass(createTensorConstantBufferizePass()); // refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. // We need to resolve this to std.alloc which takes individual extents. pm.addNestedPass(createLowerAllocMemRefOpsPass()); pm.addNestedPass(createSCFBufferizePass()); pm.addNestedPass(createLinalgBufferizePass()); pm.addNestedPass(createStdBufferizePass()); pm.addNestedPass(createTensorBufferizePass()); pm.addPass(createFuncBufferizePass()); pm.addNestedPass(createFinalizingBufferizePass()); // TODO: Do buffer deallocation. We should be able to just drop in the // upstream pass? // At this point, we have lots of loose stuff floating around from lowering, // so it's a good time to do some general cleanups. if (options.optimize) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } // -------------------------------------------------------------------------- // Preparation for converting to an LLVM module. // -------------------------------------------------------------------------- // Now, we begin the process of lowering to LLVM's level of abstraction // (after which LLVM will take over lowering to machine code). // Lower linalg ops to loops. // TODO: Do some linalg optimizations like tiling here. pm.addNestedPass(createConvertLinalgToLoopsPass()); // Run a some cleanups. if (options.optimize) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } // -------------------------------------------------------------------------- // Final conversion to an LLVM module. // -------------------------------------------------------------------------- // Convert affine to std control flow in preparation for going to LLVM. pm.addNestedPass(createLowerAffinePass()); // Convert scf to std control flow in preparation for going to LLVM. pm.addNestedPass(createLowerToCFGPass()); // Convert functions signatures and other constructs that interface with the // runtime to the `refbackrt` dialect. pm.addPass(createLowerToRefbackrtABIPass()); // Finally, convert to LLVM dialect using our custom LowerToLLVM pass // which reuses the upstream patterns and gives us a place to add our own // patterns for our own custom ops like the refbackrt ops. pm.addPass(createLowerToLLVMPass()); // Although LLVM will clean everything up eventually, for the sake of IR // clarity while still in MLIR, run some cleanups. if (options.optimize) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } }