//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h" #include "../PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/TCP/IR/TCPOps.h" using namespace mlir; using namespace NPCOMP; namespace { class ConvertAdd : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tcp::AddOp op, PatternRewriter &rewriter) const override { size_t rank = op.getType().cast().getRank(); SmallVector iterators(rank, getParallelIteratorTypeName()); SmallVector accesses(/*args in + args out*/ 3, rewriter.getMultiDimIdentityMap(rank)); auto genericOp = rewriter.create( op.getLoc(), llvm::makeArrayRef({op.getType()}), ValueRange({op.lhs(), op.rhs()}), /*args_in=*/2, /*args_out=*/1, /*indexing_maps=*/accesses, /*iterator_types=*/iterators, /*function_ref=*/nullptr); Region ®ion = genericOp.region(); Block *block = rewriter.createBlock(®ion, region.begin()); for (auto operandType : op.getOperandTypes()) { block->addArgument(operandType.cast().getElementType()); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(block); Value bodyValue = rewriter.create( op.getLoc(), block->getArgument(0), block->getArgument(1)); rewriter.create(op.getLoc(), bodyValue); rewriter.replaceOp(op, genericOp.getResults()); return success(); } }; } // namespace namespace { class ConvertTCPToLinalg : public ConvertTCPToLinalgBase { public: void runOnOperation() override { ModuleOp module = getOperation(); MLIRContext *context = &getContext(); ConversionTarget target(*context); OwningRewritePatternList patterns; patterns.insert(context); target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); if (failed(applyPartialConversion(module, target, patterns))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::NPCOMP::createConvertTCPToLinalgPass() { return std::make_unique(); }