//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Conversion/TorchToStd/TorchToStd.h" #include "../PassDetail.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Traits.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Torch; // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) // ----------------------------------------------------------------------------- // This is going to eventually be O(#torch operators), which is in the 100s. namespace { // Note: Confusingly, ATen's "dim" means "number of dimensions" which is what // MLIR calls "rank". class ConvertAtenDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto rank = rewriter.create(op->getLoc(), operands[0]); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); } }; } // namespace namespace { class ConvertAtenNeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNeIntOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, CmpIPredicate::ne, operands[0], operands[1]); return success(); } }; } // namespace namespace { class ConvertAtenGtIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenGtIntOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, CmpIPredicate::sgt, operands[0], operands[1]); return success(); } }; } // namespace namespace { template class ConvertTorchConstantOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(OpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.valueAttr()); return success(); } }; } // namespace // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- namespace { class ConvertTorchToStd : public ConvertTorchToStdBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createConvertTorchToStdPass() { return std::make_unique(); }