//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Conversion/TorchToIREE/TorchToIREE.h" #include "../PassDetail.h" #include "iree-dialects/Dialect/IREE/IREEOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Torch; //===----------------------------------------------------------------------===// // The patterns //===----------------------------------------------------------------------===// namespace { class ConvertPrimListConstructOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PrimListConstructOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = getTypeConverter()->convertType(op.getType()); auto capacity = rewriter.create(op.getLoc(), op->getNumOperands()); auto ireeList = rewriter.replaceOpWithNewOp(op, type, capacity); for (int i = 0, e = operands.size(); i != e; ++i) { auto index = rewriter.create(op.getLoc(), i); rewriter.create(op.getLoc(), ireeList, index, operands[i]); } return success(); } }; } // namespace //===----------------------------------------------------------------------===// // The pass //===----------------------------------------------------------------------===// namespace { class ConvertTorchToIREE : public ConvertTorchToIREEBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); patterns.add(typeConverter, context); target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createConvertTorchToIREEPass() { return std::make_unique(); }