//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;

//===----------------------------------------------------------------------===//
// Creating module metadata.
//===----------------------------------------------------------------------===//

// Returns true if the function signature can be expressed with the npcomprt
// ABI.
static bool expressibleWithNpcomprtABI(FunctionType type) {
  // Currently, only tensor types can be exposed at npcomprt ABI boundaries.
  return llvm::all_of(
      llvm::concat<const Type>(type.getInputs(), type.getResults()),
      [](Type t) { return t.isa<TensorType>(); });
}

static LogicalResult createModuleMetadata(ModuleOp module) {
  auto moduleMetadata =
      OpBuilder::atBlockBegin(module.getBody())
          .create<npcomprt::ModuleMetadataOp>(module.getLoc());
  moduleMetadata.metadatas().push_back(new Block);
  Block &metadatas = moduleMetadata.metadatas().front();
  OpBuilder::atBlockEnd(&metadatas)
      .create<npcomprt::ModuleMetadataTerminatorOp>(module.getLoc());

  SymbolTable symbolTable(module);
  auto builder = OpBuilder::atBlockBegin(&metadatas);
  for (auto func : module.getOps<FuncOp>()) {
    if (symbolTable.getSymbolVisibility(func) !=
        SymbolTable::Visibility::Public) {
      continue;
    }
    // TODO: Add richer information here such as expected shapes and element
    // types.
    builder.create<npcomprt::FuncMetadataOp>(
        func.getLoc(), builder.getSymbolRefAttr(func.getName()),
        builder.getI32IntegerAttr(func.getNumArguments()),
        builder.getI32IntegerAttr(func.getNumResults()));

    if (!expressibleWithNpcomprtABI(func.getType()))
      return func.emitError() << "func not expressible with npcomprt ABI";
  }
  return success();
}
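// As a rough sketch of what createModuleMetadata produces (the attribute
// names below follow the FuncMetadataOp builder arguments above; the exact
// printed syntax of the npcomprt ops may differ), a public function such as:
//
//   func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32>
//
// gets a metadata entry along the lines of:
//
//   npcomprt.module_metadata {
//     npcomprt.func_metadata {funcName = @identity,
//                             numInputs = 1 : i32,
//                             numOutputs = 1 : i32}
//   }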
//===----------------------------------------------------------------------===//
// Dialect conversion.
//===----------------------------------------------------------------------===//

namespace {
// Lowers tensor_store ops that store a tensor crossing the npcomprt ABI
// boundary: the tensor is converted to an unranked memref, cast to the
// destination memref's ranked type, and copied with linalg.copy.
class LowerTensorStoreOp : public OpConversionPattern<TensorStoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TensorStoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TensorStoreOp::Adaptor adaptor(operands);
    auto memrefType = op.memref().getType().cast<MemRefType>();
    Value abiMemref = rewriter.create<npcomprt::ToMemrefOp>(
        op.getLoc(),
        UnrankedMemRefType::get(memrefType.getElementType(),
                                /*memorySpace=*/0),
        adaptor.tensor());
    auto memref =
        rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref, memrefType);
    rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, memref, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// Lowers tensor_load by casting the source memref to an unranked memref and
// wrapping it as an npcomprt tensor.
class LowerTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TensorLoadOp::Adaptor adaptor(operands);
    auto abiMemref = rewriter.create<MemRefCastOp>(
        op.getLoc(), adaptor.memref(),
        UnrankedMemRefType::get(
            adaptor.memref().getType().cast<MemRefType>().getElementType(),
            /*memorySpace=*/0));
    rewriter.replaceOpWithNewOp<npcomprt::FromMemrefOp>(
        op, rewriter.getType<npcomprt::TensorType>(), abiMemref);
    return success();
  }
};
} // namespace

namespace {
// Lowers shape.shape_of on an npcomprt tensor to a series of
// npcomprt.get_extent ops, one per dimension.
class LowerShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::ShapeOfOp::Adaptor adaptor(operands);
    // TODO: For now npcomp only supports ranked tensor types for its shape
    // lowering, since we don't have a runtime shape struct and we lower all
    // shapes to individual SSA values.
    auto tensorType = op.arg().getType().cast<RankedTensorType>();
    SmallVector<Value, 6> extents;
    for (int i = 0, e = tensorType.getRank(); i < e; i++) {
      auto ci = rewriter.create<ConstantOp>(op.getLoc(),
                                            rewriter.getI32IntegerAttr(i));
      // TODO: Shouldn't the index type for the output be inferred since
      // https://reviews.llvm.org/rG31f40f603d0c00b313397196124c5f39090badf0
      // ?
      extents.push_back(rewriter.create<npcomprt::GetExtentOp>(
          op.getLoc(), rewriter.getIndexType(), adaptor.arg(), ci));
    }
    auto newShape = rewriter.create<shape::FromExtentsOp>(
        op.getLoc(), rewriter.getType<shape::ShapeType>(), extents);
    // TODO: Provide a builder that doesn't require the result type.
    rewriter.replaceOpWithNewOp<shape::ToExtentTensorOp>(
        op,
        RankedTensorType::get({ShapedType::kDynamicSize},
                              rewriter.getIndexType()),
        newShape);
    return success();
  }
};
} // namespace

namespace {
// Lowers tcp.global to the equivalent npcomprt.global.
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
                                                    op.value());
    return success();
  }
};
} // namespace
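// Taken together, LowerGlobalOp above and LowerGetGlobalMemrefOp below move
// tcp globals onto the npcomprt runtime. As an approximate sketch (the op
// syntax here is illustrative, not the dialects' exact printed forms):
//
//   tcp.global @g dense<7.0> : tensor<10xf32>
//   %m = tcp.get_global_memref @g : memref<10xf32>
//
// becomes:
//
//   npcomprt.global @g dense<7.0> : tensor<10xf32>
//   %u = npcomprt.get_global @g : memref<*xf32>
//   %m = memref_cast %u : memref<*xf32> to memref<10xf32>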
namespace {
// Lowers tcp.get_global_memref to npcomprt.get_global, which returns an
// unranked memref that is then cast back to the original ranked type.
class LowerGetGlobalMemrefOp
    : public OpConversionPattern<tcp::GetGlobalMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto abiMemrefType = UnrankedMemRefType::get(
        op.getType().cast<ShapedType>().getElementType(), /*memorySpace=*/0);
    auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
        op.getLoc(), abiMemrefType, op.global());
    // Cast back to the original type.
    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
    return success();
  }
};
} // namespace

static LogicalResult doDialectConversion(ModuleOp module) {
  auto *context = module.getContext();

  TypeConverter converter;
  converter.addConversion([](TensorType type) {
    return npcomprt::TensorType::get(type.getContext());
  });
  converter.addConversion([](npcomprt::TensorType type) { return type; });

  OwningRewritePatternList patterns;
  ConversionTarget target(*context);

  populateFuncOpTypeConversionPattern(patterns, context, converter);
  target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
    return converter.isSignatureLegal(op.getType());
  });

  patterns.insert<LowerTensorStoreOp>(context);
  target.addIllegalOp<TensorStoreOp>();
  target.addLegalOp<npcomprt::ToMemrefOp>();
  target.addLegalOp<linalg::CopyOp>();
  target.addLegalOp<MemRefCastOp>();

  patterns.insert<LowerTensorLoadOp>(context);
  target.addIllegalOp<TensorLoadOp>();
  target.addLegalOp<npcomprt::FromMemrefOp>();

  patterns.insert<LowerShapeOfOp>(context);
  target.addIllegalOp<shape::ShapeOfOp>();
  target.addLegalOp<ConstantOp>();
  target.addLegalOp<npcomprt::GetExtentOp>();
  target.addLegalOp<shape::FromExtentsOp>();
  target.addLegalOp<shape::ToExtentTensorOp>();

  patterns.insert<LowerGlobalOp>(context);
  target.addIllegalOp<tcp::GlobalOp>();
  target.addLegalOp<npcomprt::GlobalOp>();

  patterns.insert<LowerGetGlobalMemrefOp>(context);
  target.addIllegalOp<tcp::GetGlobalMemrefOp>();
  target.addLegalOp<npcomprt::GetGlobalOp>();

  return applyPartialConversion(module, target, patterns);
}

namespace {
// This pass lowers the public ABI of the module to the primitives exposed by
// the npcomprt dialect.
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<npcomprt::NpcomprtDialect>();
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();

    // Before we lower anything, capture any needed metadata about the
    // argument lists that will be needed for safely invoking the raw runtime
    // functions later (for example, the number of expected arguments/results,
    // types, etc.).
    if (failed(createModuleMetadata(module)))
      return signalPassFailure();

    // Now do the actual conversion / lowering.
    if (failed(doDialectConversion(module)))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createLowerToNpcomprtABIPass() {
  return std::make_unique<LowerToNpcomprtABI>();
}
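// Note: assuming the tablegen'd registration in the generated
// LowerToNpcomprtABIBase exposes this pass under the mnemonic
// "lower-to-npcomprt-abi" (an assumption; check the pass tablegen file),
// it can be exercised in isolation with something like:
//
//   npcomp-opt -lower-to-npcomprt-abi input.mlir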