//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) // ----------------------------------------------------------------------------- // This is going to eventually be O(#torch operators), which is in the 100s. namespace { // Note: Confusingly, ATen's "dim" means "number of dimensions" which is what // MLIR calls "rank". 
class ConvertAtenDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rank = rewriter.create(op->getLoc(), adaptor.self()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); } }; } // namespace namespace { class ConvertRuntimeAssertOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.condition(), adaptor.message()); return success(); } }; } // namespace namespace { template class ConvertAtenBinaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.template replaceOpWithNewOp(op, adaptor.a(), adaptor.b()); return success(); } }; } // namespace namespace { template class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = adaptor.a(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); if (!input.getType().isa()) input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type()); Value result = rewriter.create(loc, input); rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, resultType)); return success(); } }; } // namespace namespace { // Lowers aten integer comparison ops. 
template class ConvertAtenIntComparisonOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, Pred, adaptor.a(), adaptor.b()); return success(); } }; } // namespace namespace { // Lowers aten float and float_int comparison ops. template class ConvertAtenFloatComparisonOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.a(), rhs = adaptor.b(); rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType()); rewriter.replaceOpWithNewOp(op, Pred, lhs, rhs); return success(); } }; } // namespace // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. 
namespace {
// NOTE(review): in this copy of the file every template argument list appears
// to have been stripped. Examples: `OpConversionPattern` with no op type,
// `rewriter.replaceOpWithNewOp(` / `rewriter.create(` with no op to build,
// `dyn_cast()` with no target type, `registry.insert()` /
// `target.addIllegalOp()` / `patterns.add(` with no arguments, and
// `std::make_unique()` with no type. `®istry` below also looks like a
// mis-decoded `&registry`. The code is kept token-for-token as found; restore
// the missing `<...>` arguments from the upstream source before building.

/// Replaces a `ValueTensorLiteralOp` with a constant built from its
/// `valueAttr()`, rewriting integer element types to builtin `IntegerType`s
/// of the same bit width (see the comment preceding this namespace).
class ConvertTorchTensorLiteralOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = ValueTensorLiteralOp::Adaptor;
  LogicalResult
  matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    // Integer elements attribute (cast target missing here; presumably an
    // integer dense-elements attribute, given the APInt mapping below). Each
    // value is rebuilt as a `bitWidth`-bit APInt from its sign-extended value.
    if (auto elements = op.valueAttr().dyn_cast()) {
      Type elemTy = op.valueAttr().getElementType();
      unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
      Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
      rewriter.replaceOpWithNewOp(
          op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
            return APInt(bitWidth, v.getSExtValue());
          }));
      return success();
    }
    // Opaque elements (cast targets missing here; presumably an opaque
    // elements attribute on a ranked tensor with an integer element type,
    // given how it is rebuilt below). Keep the dialect and payload, swap in a
    // builtin integer element type of the same width.
    if (auto elements = op.valueAttr().dyn_cast()) {
      if (auto type = elements.getType().dyn_cast()) {
        if (auto intType = type.getElementType().dyn_cast()) {
          Type builtinTensorElemTy =
              IntegerType::get(context, intType.getIntOrFloatBitWidth());
          auto shapedType =
              RankedTensorType::get(type.getShape(), builtinTensorElemTy);
          rewriter.replaceOpWithNewOp(
              op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
                                          elements.getValue()));
          return success();
        }
      }
    }
    // Any other attribute (including non-integer opaque elements) is reused
    // as-is.
    rewriter.replaceOpWithNewOp(op, op.valueAttr());
    return success();
  }
};
} // namespace

namespace {
/// Replaces a Torch constant op of type `OpTy` with a new op built directly
/// from the same `valueAttr()`.
template class ConvertTorchConstantOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, op.valueAttr());
    return success();
  }
};
} // namespace

namespace {
/// Folds `AtenAnyBoolOp` to a constant boolean: true iff any element of its
/// `self` list is true.
///
/// Only matches when the list comes from a list-construct op and every
/// element is a constant bool. Otherwise it reports a match failure and
/// leaves the op untouched.
class ConvertAtenAnyBoolOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenAnyBoolOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector inputListTorchBool;
    if (!getListConstructElements(op.self(), inputListTorchBool)) {
      return rewriter.notifyMatchFailure(
          op, "Unimplemented input list not constructed from ListConstruct");
    }
    SmallVector inputListBool;
    for (Value v : inputListTorchBool) {
      bool cst;
      if (!matchPattern(v, m_TorchConstantBool(&cst)))
        return rewriter.notifyMatchFailure(
            op, "only support constant bool input list elements");
      inputListBool.push_back(cst);
    }
    // llvm::any_of over an empty list is false, so an empty list folds to
    // false.
    bool result = llvm::any_of(
        inputListBool, [](bool inputListElem) { return inputListElem; });
    rewriter.replaceOpWithNewOp(
        op, rewriter.getBoolAttr(result));
    return success();
  }
};
} // namespace

namespace {
/// Lowers a scalar-to-bool op (operand `a`). It compares `a` against a zero
/// of `a`'s own type using the comparison op and predicate `Pred` supplied as
/// template arguments, then selects between constant true and false on the
/// comparison result.
// NOTE(review): selecting true/false on `cmpPred` yields the same value as
// `cmpPred` itself; the select may be removable if the result types agree.
template class ConvertAtenBoolLikeOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type inputType = adaptor.a().getType();
    Value cstZero = rewriter.create(
        loc, rewriter.getZeroAttr(inputType));
    Value cstTrue =
        rewriter.create(loc, rewriter.getBoolAttr(true));
    Value cstFalse =
        rewriter.create(loc, rewriter.getBoolAttr(false));

    Value cmpPred;
    cmpPred = rewriter.create(loc, Pred, adaptor.a(), cstZero);
    rewriter.replaceOpWithNewOp(op, cmpPred, cstTrue,
                                                 cstFalse);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

namespace {
/// Pass driver. It marks the handled Torch ops illegal, registers one lowering
/// pattern per op (or group of ops), and runs a partial conversion over the
/// pass's operation.
class ConvertTorchToStd : public ConvertTorchToStdBase {
public:
  /// Registers the dialects the lowered IR may contain, plus those required
  /// by the backend type conversion. The five dialect types are missing from
  /// the `insert` calls in this copy.
  void getDependentDialects(DialectRegistry ®istry) const override {
    registry.insert();
    registry.insert();
    registry.insert();
    registry.insert();
    registry.insert();
    TorchConversion::getBackendTypeConversionDependentDialects(registry);
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect();

    // Identity conversion registered first. MLIR's TypeConverter tries the
    // most recently added conversion first, so this acts as the fallback for
    // types the backend conversion below does not handle.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    // Each `target.addIllegalOp()` is followed by the `patterns.add` call(s)
    // that lower those ops. The op and pattern template arguments are missing
    // in this copy, so the exact op list cannot be read from here.
    RewritePatternSet patterns(context);
    target.addIllegalOp();
    patterns.add(typeConverter, context);
    target.addIllegalOp();
    patterns.add(typeConverter, context);
    target.addIllegalOp();
    patterns
        .add>(
            typeConverter, context);
    patterns
        .add>(
            typeConverter, context);
    patterns.add<
        ConvertAtenIntComparisonOp>(
        typeConverter, context);
    target.addIllegalOp();
    patterns.add<
        ConvertAtenFloatComparisonOp>(
        typeConverter, context);
    patterns.add>(
        typeConverter, context);
    patterns.add>(
        typeConverter, context);
    patterns.add>(
        typeConverter, context);
    target.addIllegalOp();
    patterns.add(typeConverter, context);
    target.addIllegalOp();
    patterns.add>(typeConverter, context);
    target.addIllegalOp();
    patterns.add>(typeConverter, context);
    target.addIllegalOp();
    patterns.add>(typeConverter, context);
    target.addIllegalOp();
    patterns.add>(
        typeConverter, context);
    patterns.add>(
        typeConverter, context);
    patterns.add>(
        typeConverter, context);
    target.addIllegalOp();
    patterns.add>(
        typeConverter, context);
    target.addIllegalOp();
    patterns.add>(
        typeConverter, context);
    target.addIllegalOp();
    patterns
        .add>(
            typeConverter, context);
    target.addIllegalOp();
    patterns.add>(
        typeConverter, context);
    target.addIllegalOp();
    patterns.add(typeConverter, context);
    target.addIllegalOp();
    patterns.add<
        ConvertAtenBoolLikeOp>(
        typeConverter, context);
    patterns.add<
        ConvertAtenBoolLikeOp>(
        typeConverter, context);

    // Partial conversion: only ops marked illegal above must be rewritten;
    // failure to legalize them fails the pass.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

/// Creates an instance of the ConvertTorchToStd pass.
std::unique_ptr>
mlir::torch::createConvertTorchToStdPass() {
  return std::make_unique();
}