//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#torch operators), which is in the 100s.

namespace {
// Note: Confusingly, ATen's "dim" means "number of dimensions", which is what
// MLIR calls "rank".
class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto rank =
        rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        op, getTypeConverter()->convertType(op.getType()), rank);
    return success();
  }
};
} // namespace

namespace {
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<cf::AssertOp>(op, adaptor.getCondition(),
                                              adaptor.getMessage());
    return success();
  }
};
} // namespace

namespace {
template <typename AtenOp, typename BinOp>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
                                                adaptor.getB());
    return success();
  }
};
} // namespace

namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getA();
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    // Math dialect ops operate on floats, so widen a non-float input to f64
    // and cast the result back to the expected result type.
    if (!input.getType().isa<mlir::FloatType>())
      input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
    Value result = rewriter.create<UnaryOp>(loc, input);
    rewriter.replaceOp(
        op, convertScalarToDtype(rewriter, loc, result, resultType));
    return success();
  }
};
} // namespace
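// Example: a sketch of the IR the `aten.div.int` pattern below produces (SSA
// names are illustrative, not actual pass output). Torch defines scalar
// `div.int` as true division yielding a float, so both integer operands are
// widened to f64 before dividing:
//   %lhs = arith.sitofp %a : i64 to f64
//   %rhs = arith.sitofp %b : i64 to f64
//   %quot = arith.divf %lhs, %rhs : f64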
namespace {
class ConvertAtenDivIntOp : public OpConversionPattern<AtenDivIntOp> {
public:
  using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenDivIntOp op,
                  typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
                                   rewriter.getF64Type());
    Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
                                   rewriter.getF64Type());
    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
    return success();
  }
};
} // namespace

namespace {
// Lowers aten integer comparison ops.
template <typename AtenOp, arith::CmpIPredicate Pred>
class ConvertAtenIntComparisonOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, Pred, adaptor.getA(),
                                               adaptor.getB());
    return success();
  }
};
} // namespace

namespace {
// Lowers aten float and float_int comparison ops.
template <typename AtenOp, arith::CmpFPredicate Pred>
class ConvertAtenFloatComparisonOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.getA(), rhs = adaptor.getB();
    rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType());
    rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, Pred, lhs, rhs);
    return success();
  }
};
} // namespace

// Tensors with integer types need to be converted to a signless integer
// element type. All tensors with element types other than integer can reuse
// the existing elements attribute.
namespace {
class ConvertTorchTensorLiteralOp
    : public OpConversionPattern<ValueTensorLiteralOp> {
public:
  using OpConversionPattern<ValueTensorLiteralOp>::OpConversionPattern;
  using OpAdaptor = ValueTensorLiteralOp::Adaptor;
  LogicalResult
  matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
      if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
        Type elemTy = op.getValueAttr().getElementType();
        unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
        Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
        auto shapedType =
            RankedTensorType::get(type.getShape(), builtinTensorElemTy);
        auto rawData = elements.getRawData();
        DenseElementsAttr newAttr =
            DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
        return success();
      }
    }
    if (auto elements =
            op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
      if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
        if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
          Type builtinTensorElemTy =
              IntegerType::get(context, intType.getIntOrFloatBitWidth());
          auto shapedType =
              RankedTensorType::get(type.getShape(), builtinTensorElemTy);
          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
              op, DenseResourceElementsAttr::get(shapedType,
                                                 elements.getRawHandle()));
          return success();
        }
      }
    }
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValueAttr());
    return success();
  }
};
} // namespace

namespace {
template <typename OpTy>
class ConvertTorchConstantOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValueAttr());
    return success();
  }
};
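// Example: a sketch of the rewrite performed by the pattern below (SSA names
// are illustrative). `torch.constant.int` carries a signed 64-bit integer
// attribute, while `arith.constant` only accepts signless integers:
//   %int3 = torch.constant.int 3
// becomes
//   %c3 = arith.constant 3 : i64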
class ConvertTorchConstantIntOp
    : public OpConversionPattern<Torch::ConstantIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = Torch::ConstantIntOp::Adaptor;
  LogicalResult
  matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note: arith.constant only accepts signless integers, so convert the
    // signed value to signless.
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, rewriter.getIntegerAttr(rewriter.getI64Type(),
                                    op.getValueAttr().getValue()));
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenFloatScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    Value result = convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(),
                                        resultType);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    Value operandA =
        convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
    Value operandB =
        convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
    if (resultType.isa<mlir::FloatType>()) {
      rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
    } else if (resultType.isa<mlir::IntegerType>()) {
      rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
    } else {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only support integer or float result type");
    }
    return success();
  }
};
} // namespace

namespace {
template <typename OpTy, typename BinOp>
class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;
  virtual bool reductionFunction(ArrayRef<bool> inputArray) const = 0;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value result;
    SmallVector<Value> inputListTorchBool;
    if (!getListConstructElements(op.getSelf(), inputListTorchBool)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: input list not constructed from ListConstruct");
    }
    SmallVector<Value> inputList = getTypeConvertedValues(
        rewriter, loc, this->getTypeConverter(), inputListTorchBool);
    // Fold the list of i1 values with the binary op (or for "any", and for
    // "all").
    result = inputList[0];
    for (unsigned i = 1; i < inputList.size(); i++)
      result = rewriter.create<BinOp>(loc, result, inputList[i]);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ConvertAtenAnyOp
    : public ConvertAtenAnyOrAllBoolOp<AtenAnyBoolOp, arith::OrIOp> {
  using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp;
  bool reductionFunction(ArrayRef<bool> inputArray) const override {
    return llvm::any_of(inputArray,
                        [](bool inputListElem) { return inputListElem; });
  }
};

class ConvertAtenAllOp
    : public ConvertAtenAnyOrAllBoolOp<AtenAllBoolOp, arith::AndIOp> {
  using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp;
  bool reductionFunction(ArrayRef<bool> inputArray) const override {
    return llvm::all_of(inputArray,
                        [](bool inputListElem) { return inputListElem; });
  }
};
} // namespace
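// Example: `aten.Bool.float` and `aten.Bool.int` test whether a scalar is
// nonzero. A sketch of the lowering for the int case (SSA names are
// illustrative):
//   %zero = arith.constant 0 : i64
//   %true = arith.constant true
//   %false = arith.constant false
//   %nonzero = arith.cmpi ne, %a, %zero : i64
//   %result = arith.select %nonzero, %true, %false : i1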
namespace {
template <typename OpTy, typename CmpOpTy, typename CmpOpPred, CmpOpPred Pred>
class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type inputType = adaptor.getA().getType();
    Value cstZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(inputType));
    Value cstTrue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value cstFalse =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));

    Value cmpPred =
        rewriter.create<CmpOpTy>(loc, Pred, adaptor.getA(), cstZero);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpPred, cstTrue,
                                                 cstFalse);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchToArith
    : public ConvertTorchToArithBase<ConvertTorchToArith> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<cf::ControlFlowDialect>();
    registry.insert<math::MathDialect>();
    TorchConversion::getBackendTypeConversionDependentDialects(registry);
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           cf::ControlFlowDialect, math::MathDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    RewritePatternSet patterns(context);
    target.addIllegalOp<AtenDimOp>();
    patterns.add<ConvertAtenDimOp>(typeConverter, context);
    target.addIllegalOp<RuntimeAssertOp>();
    patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
    target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
    patterns.add<
        ConvertAtenIntComparisonOp<AtenNeIntOp, arith::CmpIPredicate::ne>>(
        typeConverter, context);
    patterns.add<
        ConvertAtenIntComparisonOp<AtenEqIntOp, arith::CmpIPredicate::eq>>(
        typeConverter, context);
    patterns.add<
        ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
        typeConverter, context);
    patterns.add<
        ConvertAtenIntComparisonOp<AtenGeIntOp, arith::CmpIPredicate::sge>>(
        typeConverter, context);
    target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
                        AtenGtFloatIntOp>();
    patterns.add<
        ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
        typeConverter, context);
    patterns.add<ConvertAtenFloatComparisonOp<AtenGeFloatIntOp,
                                              arith::CmpFPredicate::UGE>>(
        typeConverter, context);
    patterns.add<ConvertAtenFloatComparisonOp<AtenNeFloatIntOp,
                                              arith::CmpFPredicate::UNE>>(
        typeConverter, context);
    patterns.add<ConvertAtenFloatComparisonOp<AtenGtFloatIntOp,
                                              arith::CmpFPredicate::UGT>>(
        typeConverter, context);
    target.addIllegalOp<ValueTensorLiteralOp>();
    patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
    target.addIllegalOp<ConstantBoolOp>();
    patterns.add<ConvertTorchConstantOp<ConstantBoolOp>>(typeConverter,
                                                         context);
    target.addIllegalOp<Torch::ConstantFloatOp>();
    patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
                                                                 context);
    target.addIllegalOp<Torch::ConstantIntOp>();
    patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
    target.addIllegalOp<AtenFloatScalarOp>();
    patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
    target.addIllegalOp<AtenAddOp>();
    patterns.add<ConvertAtenAddOp>(typeConverter, context);
    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
    patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
        typeConverter, context);
    patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
        typeConverter, context);
    patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
        typeConverter, context);
    target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
    patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
        typeConverter, context);
    patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
        typeConverter, context);
    target.addIllegalOp<AtenDivIntOp>();
    patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
    target.addIllegalOp<AtenDivFloatOp>();
    patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
        typeConverter, context);
    target.addIllegalOp<AtenFloordivIntOp>();
    patterns.add<ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
        typeConverter, context);
    target.addIllegalOp<PrimMaxIntOp>();
    patterns.add<ConvertAtenBinaryOp<PrimMaxIntOp, arith::MaxSIOp>>(
        typeConverter, context);
    target.addIllegalOp<AtenCeilFloatOp>();
    patterns
        .add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
            typeConverter, context);
    target.addIllegalOp<AtenSqrtIntOp>();
    patterns.add<ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
        typeConverter, context);
    target.addIllegalOp<AtenAnyBoolOp, AtenAllBoolOp>();
    patterns.add<ConvertAtenAnyOp>(typeConverter, context);
    patterns.add<ConvertAtenAllOp>(typeConverter, context);
    target.addIllegalOp<AtenBoolFloatOp, AtenBoolIntOp>();
    patterns.add<
        ConvertAtenBoolLikeOp<AtenBoolFloatOp, arith::CmpFOp,
                              arith::CmpFPredicate, arith::CmpFPredicate::UNE>>(
        typeConverter, context);
    patterns.add<
        ConvertAtenBoolLikeOp<AtenBoolIntOp, arith::CmpIOp,
                              arith::CmpIPredicate, arith::CmpIPredicate::ne>>(
        typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToArithPass() {
  return std::make_unique<ConvertTorchToArith>();
}
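// Usage sketch: running this conversion in isolation (assumes a standard
// torch-mlir build and that the pass is registered as
// `convert-torch-to-arith` in Passes.td):
//   torch-mlir-opt --convert-torch-to-arith input.mlir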