//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "npcomp/RefBackend/RefBackend.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h" #include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h" using namespace mlir; using namespace mlir::NPCOMP; namespace { class LowerExtractElementOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ExtractElementOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ExtractElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.aggregate(), adaptor.indices()); return success(); } }; } // namespace namespace { class LowerTensorFromElementsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorFromElementsOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { int numberOfElements = op.elements().size(); auto resultType = MemRefType::get( {numberOfElements}, op.getType().cast().getElementType()); Value result = rewriter.create(op.getLoc(), resultType); for (auto element : llvm::enumerate(op.elements())) { Value index = rewriter.create(op.getLoc(), element.index()); rewriter.create(op.getLoc(), element.value(), result, index); } rewriter.replaceOp(op, {result}); return success(); } }; } // namespace namespace { class LowerTensorCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorCastOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = typeConverter->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, operands[0]); return success(); } }; } // namespace namespace { class LowerTensorLoadOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorLoadOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOp(op, operands[0]); return success(); } }; } // namespace namespace { // TODO: Upstream this. class LowerStdToMemref : public LowerStdToMemrefBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](RankedTensorType type) -> Type { return MemRefType::get(type.getShape(), type.getElementType()); }); typeConverter.addSourceMaterialization([](OpBuilder &builder, RankedTensorType type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); return (Value)builder.create(loc, type, inputs[0]); }); typeConverter.addTargetMaterialization([](OpBuilder &builder, MemRefType type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); return (Value)builder.create(loc, type, inputs[0]); }); OwningRewritePatternList patterns; ConversionTarget target(*context); target.addLegalDialect(); // The casting ops are introduced by the type converter, so they must be // legal. target.addLegalOp(); target.addLegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createLowerStdToMemrefPass() { return std::make_unique(); }