//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "npcomp/RefBackend/RefBackend.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Refback/IR/RefbackDialect.h" #include "npcomp/Dialect/Refback/IR/RefbackOps.h" using namespace mlir; using namespace mlir::NPCOMP; namespace { class LowerExtractElementOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ExtractElementOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ExtractElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.aggregate(), adaptor.indices()); return success(); } }; } // namespace namespace { class LowerTensorFromElementsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorFromElementsOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { int numberOfElements = op.elements().size(); auto resultType = MemRefType::get( {numberOfElements}, op.getType().cast().getElementType()); Value result = rewriter.create(op.getLoc(), resultType); for (auto element : llvm::enumerate(op.elements())) { Value index = rewriter.create(op.getLoc(), element.index()); rewriter.create(op.getLoc(), element.value(), result, index); } rewriter.replaceOp(op, {result}); return success(); } }; } // namespace namespace { class LowerTensorCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorCastOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = typeConverter->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, operands[0]); return success(); } }; } // namespace namespace { // TODO: Upstream this. class LowerStdToMemref : public LowerStdToMemrefBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); BufferizeTypeConverter typeConverter; OwningRewritePatternList patterns; ConversionTarget target(*context); target.addLegalDialect(); patterns.insert(typeConverter, context); target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createLowerStdToMemrefPass() { return std::make_unique(); }