//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

#include <memory>
#include <utility>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#torch operators), which is in the 100s.

namespace {
// Note: Confusingly, ATen's "dim" means "number of dimensions" which is what
// MLIR calls "rank".
class ConvertAtenDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rank = rewriter.create(op->getLoc(), adaptor.self()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); } }; } // namespace namespace { class ConvertRuntimeAssertOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.condition(), adaptor.message()); return success(); } }; } // namespace namespace { template class ConvertAtenBinaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.template replaceOpWithNewOp(op, adaptor.a(), adaptor.b()); return success(); } }; } // namespace namespace { template class ConvertAtenUnaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); Value result = rewriter.create(op.getLoc(), adaptor.a()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType)); return success(); } }; } // namespace namespace { // Lowers aten integer comparison ops. 
template class ConvertAtenIntComparisonOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, Pred, adaptor.a(), adaptor.b()); return success(); } }; } // namespace namespace { // Lowers aten float and float_int comparison ops. template class ConvertAtenFloatComparisonOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.a(), rhs = adaptor.b(); rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType()); rewriter.replaceOpWithNewOp(op, Pred, lhs, rhs); return success(); } }; } // namespace // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. 
namespace { class ConvertTorchTensorLiteralOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = ValueTensorLiteralOp::Adaptor; LogicalResult matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); if (auto elements = op.valueAttr().dyn_cast()) { Type elemTy = op.valueAttr().getElementType(); unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); Type builtinTensorElemTy = IntegerType::get(context, bitWidth); rewriter.replaceOpWithNewOp( op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); })); return success(); } if (auto elements = op.valueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = IntegerType::get(context, intType.getIntOrFloatBitWidth()); auto shapedType = RankedTensorType::get(type.getShape(), builtinTensorElemTy); rewriter.replaceOpWithNewOp( op, OpaqueElementsAttr::get(elements.getDialect(), shapedType, elements.getValue())); return success(); } } } rewriter.replaceOpWithNewOp(op, op.valueAttr()); return success(); } }; } // namespace namespace { template class ConvertTorchConstantOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename OpTy::Adaptor; LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.valueAttr()); return success(); } }; } // namespace // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- namespace { class ConvertTorchToStd : public ConvertTorchToStdBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); 
registry.insert(); registry.insert(); registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns .add>( typeConverter, context); patterns .add>( typeConverter, context); patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); target.addIllegalOp(); patterns.add< ConvertAtenFloatComparisonOp>( typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::torch::createConvertTorchToStdPass() { return std::make_unique(); }