//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// Conversion patterns lowering basic Torch ops (elementwise arithmetic,
// comparisons, shape manipulation, normalization, ...) to MHLO/CHLO.
//
// NOTE(review): this copy of the file appears to have had every
// `<identifier ...>` sequence stripped. The affected sites are:
//   * template parameter/argument lists, e.g.
//     `template class Foo : public OpConversionPattern`, `dyn_cast()`,
//     `rewriter.create(...)` and `SmallVector shape;`;
//   * the contents of the two angle-bracket `#include`s below.
// Numeric arguments such as `to_vector<4>` survived. Comments that name a
// specific op or type at one of those stripped sites are inferences from
// the surrounding code and should be confirmed against upstream.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "utils/hlo_utils.h"
// NOTE(review): the header names of the next two includes are missing in this
// copy. The code below uses std::iota and std::string, so one of them is
// presumably <numeric> — confirm.
#include
#include

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;

// Makes `self` and `other` the same rank. The leading dimension indices
// [0, rankDifference) of the lower-ranked value are passed to
// mhlo::unsqueezeTensor. That presumably inserts size-1 dims there, so the
// ranks end up aligned from the trailing end. Both values are updated in place.
//
// If either value is rank 0, nothing is changed and success is returned.
// Returns failure if mhlo::unsqueezeTensor fails. `dimSizeIndexBits` is
// forwarded unchanged to mhlo::unsqueezeTensor.
//
// NOTE(review): both values are assumed to have ranked tensor types; the
// dyn_cast results are used without a null check.
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
                             mlir::Value &self, mlir::Value &other,
                             size_t dimSizeIndexBits) {
  auto selfTy = self.getType().template dyn_cast();
  auto otherTy = other.getType().template dyn_cast();
  auto selfRank = selfTy.getRank();
  auto otherRank = otherTy.getRank();
  // Rank-0 operands are left to the consumer's own scalar broadcasting.
  if (selfRank == 0 || otherRank == 0)
    return success();
  if (selfRank > otherRank) {
    // Grow `other` by the rank difference.
    auto unsqueezeDims =
        llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank));
    auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
                                               unsqueezeDims, dimSizeIndexBits);
    if (failed(unsqueezeInfo))
      return failure();
    other = *unsqueezeInfo;
  } else if (otherRank > selfRank) {
    // Grow `self` by the rank difference.
    auto unsqueezeDims =
        llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank));
    auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
                                               unsqueezeDims, dimSizeIndexBits);
    if (failed(unsqueezeInfo))
      return failure();
    self = *unsqueezeInfo;
  }
  return success();
}

// Returns true when `alphaValue` is a Torch constant float equal to 1.0 or a
// Torch constant int equal to 1. In that case multiplying by alpha is a no-op
// and can be skipped. A non-constant alpha returns false.
bool skipMultiplyAlpha(Value alphaValue) {
  double doubleValue;
  auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue));

  int64_t intValue;
  auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue));

  return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0));
}

// Creates a rank-0 constant tensor holding the largest value of `elementType`:
//   * First checked category (floating point, judging by getFloatSemantics):
//     the value is +infinity, not the largest finite value.
//   * Second category (integers, judging by isUnsigned/getWidth): the
//     unsigned or signed maximum for the type's bit width.
// Returns failure for any other element type. The created op type is not
// visible in this copy; presumably it is a constant op.
static FailureOr getMaxValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (elementType.isa()) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(elementType.cast().getFloatSemantics(),
                        /*negative=*/false));
    return rewriter.create(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (elementType.isa()) {
    auto integerType = elementType.cast();
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMaxValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMaxValue(integerType.getWidth()));
    }
    return rewriter.create(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}

// Creates a rank-0 constant tensor holding the smallest value of
// `elementType`:
//   * First checked category (floating point, judging by getFloatSemantics):
//     the value is -infinity, not the lowest finite value.
//   * Second category (integers, judging by isUnsigned/getWidth): the
//     unsigned minimum (0) or signed minimum for the type's bit width.
// Returns failure for any other element type.
static FailureOr getMinValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (elementType.isa()) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(elementType.cast().getFloatSemantics(),
                        /*negative=*/true));
    return rewriter.create(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (elementType.isa()) {
    auto integerType = elementType.cast();
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMinValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMinValue(integerType.getWidth()));
    }
    return rewriter.create(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}

// These legalizations are for unary ops that support only floating-point
// datatypes. There is no supported quantized integer mode for these.
namespace {
// Replaces `op` with a single new op applied to the converted `self` operand.
// The new op's type is a template argument not visible in this copy. Emits an
// error unless the operand's element type passes the element-type check,
// which the error message suggests is a float check.
template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    // NOTE(review): MLIR's cast() asserts on a type mismatch instead of
    // returning null, so the `!selfTy` check below presumably never fires.
    // dyn_cast would be needed for it to be meaningful.
    auto selfTy = self.getType().cast();

    if (!selfTy)
      return op.emitError("only Tensor types supported in MHLO");

    if (selfTy.getElementType().isa()) {
      rewriter.replaceOpWithNewOp(
          op,
          OpConversionPattern::getTypeConverter()->convertType(
              op.getType()),
          self);
      return success();
    } else {
      return op.emitError(
          "only floating-point datatype legalization supported");
    }
  }
};
} // namespace

// aten.ones & aten.zeros
// Ref: error checking is based on the Torch-to-TOSA lowering.
namespace {
// Replaces `op` with a constant of the converted result type. Every element
// equals the template parameter `fillVal`, whose declaration is not visible in
// this copy. The `size` operand must be a list of constant ints.
template class ConvertAtenConstPatternOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto outType = OpConversionPattern::getTypeConverter()
                       ->convertType(op.getType())
                       .template dyn_cast();

    if (!outType)
      return op.emitError("only Tensor types supported in MHLO");

    Type outElemTy = outType.getElementType();
    if (!outElemTy.isIntOrFloat())
      return op.emitError(
          "only floating-point or integer datatype legalization supported");

    SmallVector shape;
    if (!matchPattern(op.size(), m_TorchListOfConstantInts(shape))) {
      return op.emitError("shape must be a list of Scalar constants");
    }

    // Total element count: the product of all requested dims.
    int64_t size = 1;
    for (auto s : shape)
      size *= s;

    // Build a dense constant of `shape` filled with `fillVal`. The element's
    // C++ type is a stripped template argument. A second op (presumably a
    // convert) then produces the converted result type.
    SmallVector values(size, fillVal);
    auto constOp =
        mhlo::getConstTensor(rewriter, op, values, shape).value();

    rewriter.replaceOpWithNewOp(op, outType, constOp);
    return success();
  }
};

} // namespace

// The binary broadcast patterns
namespace {
// Replaces `op` with a single binary op on the converted `self` and `other`
// operands. The op type is not visible in this copy, and a null broadcast
// attribute is passed. Both operands must already have identical element
// types; no promotion is done here.
template class ConvertAtenBinaryBroadcastOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    // NOTE(review): cast() asserts rather than returning null, so the
    // `!lhsTy || !rhsTy` check below presumably never fires.
    auto lhsTy = lhs.getType().cast();
    Value rhs = adaptor.other();
    auto rhsTy = rhs.getType().cast();

    if (!lhsTy || !rhsTy)
      return op.emitError("only Tensor types supported");

    auto lhsElemTy = lhsTy.getElementType();
    auto rhsElemTy = rhsTy.getElementType();

    if (lhsElemTy != rhsElemTy)
      return op.emitError("input data types mismatched");

    rewriter.replaceOpWithNewOp(
        op,
        OpConversionPattern::getTypeConverter()->convertType(
            op.getType()),
        lhs, rhs,
        /*broadcast_attr*/ nullptr);
    return success();
  }
};
} // namespace

// These binary op legalizations are specific to add/sub, which have an
// alpha multiplier.
namespace {
// Lowers add/sub-style ops that compute `self <op> (other * alpha)`:
//   * A non-tensor `other` is materialized as a tensor of the result element
//     type via mhlo::scalarToMhloTensor.
//   * For one op type, the operands are then swapped. The isa<> target is not
//     visible in this copy; it is presumably the reverse-subtract op. After
//     the swap the scalar is the left-hand side and `self` is the operand
//     scaled by alpha.
//   * Both operands are promoted to the converted result type.
//   * The multiplication by alpha is skipped when alpha is the constant 1
//     (see skipMultiplyAlpha).
// The multiply and final binary op types are not visible in this copy. Both
// are created with a null broadcast-dimensions attribute.
template class ConvertAtenAddSubOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    RankedTensorType lhsType = lhs.getType().dyn_cast();
    Value rhs = adaptor.other();
    RankedTensorType rhsType = rhs.getType().dyn_cast();

    // Only `self` must be a ranked tensor; a non-tensor `other` is handled
    // below.
    if (!lhsType)
      return op.emitError("only Tensor types supported in MHLO");

    TensorType outType = OpConversionPattern::getTypeConverter()
                             ->convertType(op.getType())
                             .template cast();

    Type outElemTy = outType.getElementType();
    if (!outElemTy.isIntOrFloat()) {
      return op.emitError(
          "only floating-point or integer datatype legalization supported");
    }

    if (!rhsType) {
      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
      // Reverse variant: the scalar becomes the minuend, `self` the operand
      // that gets scaled by alpha below.
      if (isa(op)) {
        std::swap(lhs, rhs);
      }
    }

    lhs = mhlo::promoteType(rewriter, lhs, outType);
    rhs = mhlo::promoteType(rewriter, rhs, outType);

    if (!skipMultiplyAlpha(op.alpha())) {
      Value alpha =
          mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
      DenseIntElementsAttr bcastDimensions;
      rhs = rewriter.create(op->getLoc(), rhs, alpha,
                                                 bcastDimensions);
    }

    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp(op, outType, lhs, rhs,
                                               bcastDimensions);
    return success();
  }
};
} // namespace

// Binary op legalizations for Mul/Div variants.
namespace {
// Lowers mul/div-style binary ops to a single binary op. The op type is not
// visible in this copy, and a null broadcast-dimensions attribute is used.
//   * For one op type, `other` is ignored and `self` is used for both
//     operands. The std::is_same target is not visible; it is presumably a
//     square op.
//   * Otherwise a non-tensor `other` is materialized as a tensor of the result
//     element type.
//   * Both operands are promoted to the converted result type.
// For the div-with-rounding-mode op (AtenDivTensorModeOp, per the dyn_cast
// below), the quotient is post-processed based on the constant `rounding_mode`
// string:
//   * "trunc" and "floor" add rounding ops.
//   * Any other constant string leaves the quotient unrounded.
//   * A non-constant rounding mode is a match failure.
template class ConvertAtenMulDivOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    auto lhsType = lhs.getType().dyn_cast();
    Value rhs = adaptor.other();
    TensorType rhsType = rhs.getType().dyn_cast();

    // Only `self` must be a tensor; a non-tensor `other` is handled below.
    if (!lhsType)
      return op.emitError("only Tensor types supported in MHLO");

    auto outType = OpConversionPattern::getTypeConverter()
                       ->convertType(op.getType())
                       .template cast();

    Type outElemTy = outType.getElementType();
    if (!outElemTy.isIntOrFloat()) {
      return op.emitError(
          "only floating-point or integer datatype legalization supported");
    }

    if (std::is_same()) {
      // Single-operand variant: compute self <op> self.
      rhs = lhs;
    } else if (!rhsType) {
      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
    }
    DenseIntElementsAttr bcastDimensions;
    lhs = mhlo::promoteType(rewriter, lhs, outType);
    rhs = mhlo::promoteType(rewriter, rhs, outType);
    auto loc = op.getLoc();
    Value result =
        rewriter.create(loc, outType, lhs, rhs, bcastDimensions);

    // Every op except the rounding-mode division is finished at this point.
    if (!isa(op)) {
      rewriter.replaceOp(op, result);
      return success();
    }

    AtenDivTensorModeOp divTensorModeOp =
        llvm::dyn_cast(op.getOperation());
    std::string roundingMode;
    if (!matchPattern(divTensorModeOp.rounding_mode(),
                      m_TorchConstantStr(roundingMode)))
      return rewriter.notifyMatchFailure(
          op, "only support constant str rounding mode");

    if (roundingMode == "trunc") {
      // "trunc" - rounds the results of the division towards zero. Equivalent
      // to C-style integer division. Judging by the variable names, this is
      // computed as sign(q) * floor(|q|); the created op types are not
      // visible here.
      auto sign = rewriter.create(loc, result);
      auto abs = rewriter.create(loc, result);
      auto floor = rewriter.create(loc, abs);
      result = rewriter.create(loc, sign, floor).getResult();
    }
    if (roundingMode == "floor") {
      // "floor" - rounds the results of the division down. Equivalent to
      // floor division in Python (the // operator).
      result = rewriter.create(loc, result).getResult();
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

// Binary op legalizations for comparator ops.
namespace {
// Lowers lt/gt/ge/eq/ne/le comparisons to a single compare op with a null
// broadcast-dimensions attribute. The paired std::is_same checks suggest both
// tensor and scalar `other` variants are handled. A non-tensor `other` is
// materialized with the element type of `self`, and `other` is then promoted
// to the type of `self`.
template class ConvertAtenCompareOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    Value rhs = adaptor.other();
    RankedTensorType lhsTy = lhs.getType().dyn_cast();
    RankedTensorType rhsTy = rhs.getType().dyn_cast();

    if (!lhsTy)
      return op.emitError("only Tensor types supported in MHLO");

    RankedTensorType outType = OpConversionPattern::getTypeConverter()
                                   ->convertType(op.getType())
                                   .template cast();

    Type lhsElemTy = lhsTy.getElementType();
    if (!lhsElemTy.isIntOrFloat()) {
      return op.emitError(
          "only floating-point or integer datatype legalization supported");
    }

    if (!rhsTy) {
      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
    }

    // TODO: what is the PyTorch default type promotion?
    // For now `other` is promoted to the type of `self`.
    rhs = mhlo::promoteType(rewriter, rhs, lhsTy);

    chlo::ComparisonTypeAttr compareTypeAttr;
    chlo::ComparisonDirectionAttr compareDirectionAttr;

    // NOTE(review): every element type accepted by the second check
    // (presumably IntegerType) is compared as SIGNED, including unsigned
    // integers. An element type matching neither check leaves compareTypeAttr
    // null.
    if (lhsElemTy.isa()) {
      compareTypeAttr = chlo::ComparisonTypeAttr::get(
          op->getContext(), chlo::ComparisonType::FLOAT);
    } else if (lhsElemTy.isa()) {
      compareTypeAttr = chlo::ComparisonTypeAttr::get(
          op->getContext(), chlo::ComparisonType::SIGNED);
    }

    if (std::is_same() ||
        std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::LT);
    } else if (std::is_same() ||
               std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::GT);
    } else if (std::is_same() ||
               std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::GE);
    } else if (std::is_same() ||
               std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::EQ);
    } else if (std::is_same() ||
               std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::NE);
    } else if (std::is_same() ||
               std::is_same()) {
      // NOTE(review): this branch selects LT again, like the first branch.
      // If its (stripped) conditions name the same op types as the first
      // branch, it is unreachable — confirm.
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::LT);
    } else if (std::is_same() ||
               std::is_same()) {
      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
          op->getContext(), chlo::ComparisonDirection::LE);
    } else {
      return op.emitError("operator haven't been supported");
    }
    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp(
        op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
        compareTypeAttr);
    return success();
  }
};
} // namespace

// AtenTransposeIntOp
namespace {
// Lowers aten.transpose.int with constant `dim0`/`dim1`. Both dims are
// normalized with toPositiveDim and range-checked. The permutation is the
// identity [0, rank) with positions dim0 and dim1 swapped, and it is passed to
// a transpose op whose type is not visible in this copy.
class ConvertAtenTransposeIntOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    int64_t dim0;
    if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) {
      return rewriter.notifyMatchFailure(op, "dim0 must be constant");
    }
    int64_t dim1;
    if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) {
      return rewriter.notifyMatchFailure(op, "dim1 must be constant");
    }

    auto inType = self.getType().cast();
    auto inputRank = inType.getRank();
    auto outType = getTypeConverter()
                       ->convertType(op->getResult(0).getType())
                       .cast();

    // Negative dims count from the end.
    dim0 = toPositiveDim(dim0, inputRank);
    if (!isValidDim(dim0, inputRank)) {
      return rewriter.notifyMatchFailure(op, "dim0 out of range");
    }
    dim1 = toPositiveDim(dim1, inputRank);
    if (!isValidDim(dim1, inputRank)) {
      return rewriter.notifyMatchFailure(op, "dim1 out of range");
    }

    SmallVector permValues(inputRank);
    std::iota(std::begin(permValues), std::end(permValues), 0);
    std::swap(permValues[dim0], permValues[dim1]);
    DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
        RankedTensorType::get({static_cast(permValues.size())},
                              rewriter.getI64Type()),
        permValues);
    rewriter.replaceOpWithNewOp(op, outType, self,
                                                     permutation);
    return success();
  }
};
} // namespace

// AtenToDtypeOp
//
// Replaces aten.to.dtype with a single op, presumably an element-type convert,
// from the converted `self` to the converted result type. Only `self` and the
// result type are read; any other operands of the op are ignored here.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenToDtypeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto outType =
      getTypeConverter()->convertType(op.getType()).cast();
  rewriter.replaceOpWithNewOp(op, outType, self);
  return success();
}

// AtenSizeIntOp
//
// Lowers aten.size.int to a dimension query on the converted tensor:
//   * Constant `dim`: normalized with toPositiveDim and materialized directly
//     as a constant.
//   * Dynamic `dim`: normalized at runtime with toPositiveDimDynamic against an
//     i64 rank constant, then cast to index.
// The index-typed size is finally converted to the converted result type.
//
// NOTE(review): the constant-dim path does not call isValidDim, unlike the
// transpose/permute lowerings in this file.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenSizeIntOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Reject non-tensor operands.
  auto selfType = adaptor.self().getType().dyn_cast();
  if (!selfType)
    return op.emitError("only tensor types are currently supported");

  Value dim;
  int64_t dimInt;
  if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
    dimInt = toPositiveDim(dimInt, selfType.getRank());
    dim = rewriter.create(op.getLoc(), dimInt);
  } else {
    Value inputRank = rewriter.create(
        op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
    dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.dim(),
                               inputRank);
    dim = rewriter.create(op.getLoc(),
                                              rewriter.getIndexType(), dim);
  }

  auto dimSize = rewriter.create(
      op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);

  rewriter.replaceOpWithNewOp(
      op, getTypeConverter()->convertType(op.getType()), dimSize);

  return success();
}

// AtenWhereSelfOp
//
// Lowers aten.where.self to a select-style op with operands
// (cond, self, other), after aligning ranks with broadcastRanks.
//
// NOTE(review): ranks are aligned pairwise against `cond` only. Suppose
// `other` has a higher rank than both `self` and `cond`. The second call then
// grows `cond`, but `self` keeps its smaller rank. That is only correct if the
// select op broadcasts differing ranks itself — confirm.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenWhereSelfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter& rewriter) const {
  Value self = adaptor.self();
  Value cond = adaptor.condition();
  Value other = adaptor.other();

  if (failed(
          broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
    return op.emitError("failed broadcast self and condition ranks");

  if (failed(
          broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits)))
    return op.emitError("failed broadcast other and condition ranks");

  rewriter.replaceOpWithNewOp(
      op, getTypeConverter()->convertType(op.getType()),
      ArrayRef{cond, self, other});
  return success();
}

// AtenBroadcastToOp
//
// Static path: when options.enableStaticShape is set and the input shape is
// static, the input is promoted and broadcast to the converted result type.
//
// Dynamic path: the target shape is built from the `size` list.
//   * In the trailing selfRank positions, a constant -1 means "keep the
//     input's size" and is replaced by a dim query on the input.
//   * Every other entry is converted from the Torch value to an index,
//     presumably via i64 followed by an index cast.
//   * When options.dimSizeIndexBits == 32, each size is further cast
//     index -> i64 -> i32.
//   * An empty shape yields a single op (type not visible) to the result type.
//   * Otherwise the sizes are packed into a shape tensor, and the input is
//     broadcast with its dims mapped to [leadingRank, totalRank).
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenBroadcastToOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto selfTy = self.getType().cast();
  auto outType = getTypeConverter()
                     ->convertType(op->getResult(0).getType())
                     .cast();

  if (options.enableStaticShape && selfTy.hasStaticShape()) {
    Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
    rewriter.replaceOp(op, bcastOp);
    return success();
  }

  SmallVector shape;
  if (!(getListConstructElements(adaptor.size(), shape))) {
    return op->emitError("desired shape must be a list of scalar");
  }
  SmallVector bcastShapeVec;
  int64_t totalRank = shape.size();
  int64_t selfRank = selfTy.getRank();
  // Number of new dims prepended in front of the input's dims.
  int64_t leadingRank = totalRank - selfRank;

  for (int64_t i = 0; i < totalRank; ++i) {
    Value dValue = shape[i];
    Value newD;
    int64_t dInt;
    if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) &&
        dInt == -1) {
      // -1 on an existing dim: reuse the input's runtime size.
      newD = rewriter.create(op->getLoc(), self,
                                                  i - leadingRank);
    } else {
      dValue = rewriter.create(op->getLoc(),
                                                             dValue);
      newD = rewriter.create(
          op->getLoc(), rewriter.getIndexType(), dValue);
    }
    bcastShapeVec.push_back(newD);
  }

  if (options.dimSizeIndexBits == 32) {
    for (auto &dsize : bcastShapeVec) {
      auto dsizeI64 = rewriter.create(
          op->getLoc(), rewriter.getI64Type(), dsize);
      dsize = rewriter.create(op->getLoc(),
                                                 rewriter.getI32Type(), dsizeI64);
    }
  }

  if (bcastShapeVec.size() == 0) {
    rewriter.replaceOpWithNewOp(op, outType, self);
  } else {
    Value bcastShapeTensor = rewriter.create(
        op->getLoc(), ValueRange{bcastShapeVec});
    auto dimensionNumbers =
        llvm::to_vector<4>(llvm::seq(leadingRank, totalRank));
    rewriter.replaceOpWithNewOp(
        op, outType, self, bcastShapeTensor,
        rewriter.getI64TensorAttr(dimensionNumbers));
  }
  return success();
}

// AtenPermuteOp
//
// Lowers aten.permute with constant `dims`. Each dim is normalized with
// toPositiveDim and range-checked, then the permutation is passed to a
// transpose op whose type is not visible in this copy.
//
// NOTE(review): the error message mentions static shapes, but only the
// ranked-type dyn_cast is checked here.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenPermuteOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  // Reject non-ranked tensor types.
  auto inType = self.getType().dyn_cast();
  auto outType = getTypeConverter()
                     ->convertType(op->getResult(0).getType())
                     .cast();
  if (!inType)
    return op.emitError("only ranked tensor types with static shapes are "
                        "currently supported");

  SmallVector permValues;
  if (!matchPattern(adaptor.dims(), m_TorchListOfConstantInts(permValues)))
    return rewriter.notifyMatchFailure(
        op, "only constant dimensions are currently supported");

  int64_t inRank = inType.getRank();
  for (auto &d : permValues) {
    d = toPositiveDim(d, inRank);
    if (!isValidDim(d, inRank))
      return op.emitError("not all dims are valid");
  }

  DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast(permValues.size())},
                            rewriter.getI64Type()),
      permValues);
  rewriter.replaceOpWithNewOp(op, outType, self,
                                                   permutation);
  return success();
}

// AtenTanhOp
//
// Replaces aten.tanh with a single op (type not visible) when the element type
// passes the float check; otherwise emits an error.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenTanhOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  // NOTE(review): cast() asserts rather than returning null, so the `selfTy`
  // null test below is presumably always true.
  auto selfTy = self.getType().cast();
  if (selfTy && selfTy.getElementType().isa()) {
    rewriter.replaceOpWithNewOp(
        op, getTypeConverter()->convertType(op.getType()), self);
    return success();
  } else {
    return op.emitError(
        "only floating-point datatype legalization currently supported");
  }
}

// ValueTensorLiteralOp
//
// Materializes a Torch tensor literal as a constant of the converted result
// type.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    ValueTensorLiteralOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  RankedTensorType resultType = getTypeConverter()
                                    ->convertType(op->getResult(0).getType())
                                    .cast();

  // Tensors with integer types need to be converted to signless integer
  // element type. All tensors with element types other than integer can reuse
  // the existing elements attribute.
  // TODO: what about unsigned integer?
  if (auto elements = op.valueAttr().dyn_cast()) {
    Type builtinTensorElemTy = resultType.getElementType();
    unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();

    // Rebuild each element at the result element bit width from its
    // sign-extended 64-bit value.
    DenseElementsAttr valueAttr =
        elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
          return APInt(bitWidth, v.getSExtValue());
        });
    rewriter.replaceOpWithNewOp(op, resultType,
                                                  valueAttr);
    return success();
  }

  rewriter.replaceOpWithNewOp(op, resultType,
                                                adaptor.value());
  return success();
}

// AtenReciprocalOp
// Reciprocal(x) = Div(1, x), for floating-point inputs only. The constant 1 is
// built with chlo::getConstantLike based on `input`.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenReciprocalOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast();
  if (!inputTy.getElementType().isa()) {
    return op.emitError("only floating-point datatype legalization supported "
                        "for AtenReciprocalOp");
  }

  Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
  rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input);
  return success();
}

// PrimNumToTensorScalarOp
//
// Materializes the scalar operand `a` as a tensor with the converted result
// element type, via mhlo::scalarToMhloTensor.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    PrimNumToTensorScalarOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  RankedTensorType outputType = getTypeConverter()
                                    ->convertType(op->getResult(0).getType())
                                    .cast();
  auto outputElemType = outputType.getElementType();
  Value mhloTensor =
      mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
  rewriter.replaceOp(op, mhloTensor);
  return success();
}

// AtenContiguousOp
// Ref: the Torch-to-TOSA lowering for implementation details. The original
// comment named "TosaToTosa.cpp", presumably meaning TorchToTosa.cpp.
//
// The op is replaced by its converted input unchanged.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenContiguousOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Reject non-tensor operands.
  auto selfType = adaptor.self().getType().dyn_cast();
  if (!selfType)
    return op.emitError("only tensor types are currently supported");

  // FIXME: memory_format is not handled.
  rewriter.replaceOp(op, adaptor.self());

  return success();
}

// AtenReluOp
// Relu(x) = Max(0, x), for floating-point inputs only. The zero constant is
// built with chlo::getConstantLike based on `x`.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenReluOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.self();
  auto lhsTy = lhs.getType().cast();
  auto lhsElemTy = lhsTy.getElementType();

  if (!lhsElemTy.isa()) {
    return op->emitError("only float tensor in relu op is supported");
  }

  Value zeroTensor;
  zeroTensor = chlo::getConstantLike(
      rewriter, op->getLoc(),
      APFloat::getZero(lhsElemTy.cast().getFloatSemantics(),
                       false),
      lhs);
  rewriter.replaceOpWithNewOp(op, lhs, zeroTensor);
  return success();
}

// Convert an aten::gelu to HLO:
//   Gelu(x) = x * 1/2 * [1 + erf(x / sqrt(2))]
// Judging by the variable names (the created op types are not visible), this
// is computed as x * ((erf(x * rsqrt(2)) + 1) * 0.5).
//
// NOTE(review): no approximation-mode operand is read here. If the op carries
// one (e.g. "tanh"), that variant is still lowered as the exact erf form —
// confirm.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenGeluOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value input = adaptor.self();
  auto inputTy = input.getType().template dyn_cast();
  if (!inputTy) {
    return op.emitError("only ranked tensor type is supported.");
  }

  Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
  Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
  auto rsqrtTwo = rewriter.create(loc, two);
  auto erfElement = rewriter.create(loc, input, rsqrtTwo);
  auto erf = rewriter.create(loc, erfElement);
  auto erfAdd = rewriter.create(loc, erf, one);
  auto halfMul = rewriter.create(loc, erfAdd, half);
  rewriter.replaceOpWithNewOp(op, input, halfMul);
  return success();
}

// AtenErfOp
//
// Replaces aten.erf with a single op (type not visible) for float inputs.
// Unlike most lowerings in this file, a non-float input produces a match
// failure rather than an emitted error.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenErfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputType = input.getType().cast();
  if (!inputType.getElementType().isa()) {
    return rewriter.notifyMatchFailure(op, "only float tensor is supported");
  }
  rewriter.replaceOpWithNewOp(
      op, getTypeConverter()->convertType(op.getType()), input);
  return success();
}
// AtenBatchNormOp
//
// Lowers aten.batch_norm with the channel dimension at dim 1. Three things
// agree on this: the dim queried for the channel size, the `getShape()[1]`
// used for default parameter shapes, and the feature index 1 passed to the
// MHLO ops.
//
// A missing (None) weight, bias, running_var or running_mean is replaced by a
// constant of 1, 0, 1 or 0 respectively, with one element per channel.
//   * training == true: emits a batch-norm-training op and returns only its
//     normalized output. The batch mean/variance results are discarded, and
//     the running statistics are neither used nor updated. momentum is
//     ignored.
//   * training == false: casts the input so its channel dim matches the
//     weight, runs mhlo::BatchNormInferenceOp with the running statistics,
//     and converts the result to the converted result type.
// `cudnn_enabled` must be a constant bool but is otherwise ignored. `eps` is
// passed as a 32-bit float attribute, narrowing it from double.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenBatchNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.input();
  // Expected layout: [N, C, ...] with rank > 2, e.g. [N, C, H, W].
  auto inputTy = input.getType().cast();
  Value weight = adaptor.weight();
  Value bias = adaptor.bias();
  Value runningMean = adaptor.running_mean();
  Value runningVar = adaptor.running_var();
  // momentum is ignored
  Value momentum = adaptor.momentum();
  (void)momentum;

  if (inputTy.getRank() <= 2) {
    return rewriter.notifyMatchFailure(op,
                                       "input should have rank larger than 2");
  }
  if (!inputTy.getElementType().template isa()) {
    return op.emitError("only input tensor of float type is supported");
  }
  auto inputElemTy = inputTy.getElementType().cast();

  // Runtime channel count, used as the shape of default parameters.
  Value channelDim = rewriter.create(op->getLoc(), input, 1);

  if (options.dimSizeIndexBits == 32) {
    auto channelDimI64 = rewriter.create(
        op->getLoc(), rewriter.getI64Type(), channelDim);
    channelDim = rewriter.create(
        op->getLoc(), rewriter.getI32Type(), channelDimI64);
  }

  Value channelShape = rewriter.create(
      op->getLoc(), ValueRange{channelDim});
  // Substitute per-channel defaults for None parameters.
  if (failed(checkNotNone(rewriter, op, weight))) {
    weight = mhlo::getConstantOfShape(
        rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
        channelShape,
        RankedTensorType::get({inputTy.getShape()[1]},
                              inputTy.getElementType()));
  }
  if (failed(checkNotNone(rewriter, op, bias))) {
    bias = mhlo::getConstantOfShape(
        rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
        channelShape,
        RankedTensorType::get({inputTy.getShape()[1]},
                              inputTy.getElementType()));
  }
  if (failed(checkNotNone(rewriter, op, runningVar))) {
    runningVar = mhlo::getConstantOfShape(
        rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
        channelShape,
        RankedTensorType::get({inputTy.getShape()[1]},
                              inputTy.getElementType()));
  }
  if (failed(checkNotNone(rewriter, op, runningMean))) {
    runningMean = mhlo::getConstantOfShape(
        rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
        channelShape,
        RankedTensorType::get({inputTy.getShape()[1]},
                              inputTy.getElementType()));
  }

  auto weightTy = weight.getType().cast();
  auto biasTy = bias.getType().cast();
  auto runningMeanTy = runningMean.getType().cast();
  auto runningVarTy = runningVar.getType().cast();

  if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
      runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
    return rewriter.notifyMatchFailure(
        op, "expect weight, bias, running_mean and running_var to be rank 1");
  }
  if (!weightTy.getElementType().template isa() ||
      !biasTy.getElementType().template isa() ||
      !runningMeanTy.getElementType().template isa() ||
      !runningVarTy.getElementType().template isa()) {
    return op.emitError("only float weight/bias/runningMean/runningVar tensor "
                        "of float type is supported");
  }

  double eps = 0.0;
  if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) {
    return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported");
  }
  bool training = false;
  if (!matchPattern(op.training(), m_TorchConstantBool(&training))) {
    return rewriter.notifyMatchFailure(op, "non-bool training unsupported");
  }
  // TODO: handle cudnnEnabled parameter. Here, we just ignore it!
  bool cudnnEnabled = false;
  if (!matchPattern(op.cudnn_enabled(), m_TorchConstantBool(&cudnnEnabled))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-bool cudnn_enabled unsupported");
  }
  if (training) {
    Type outputTy = getTypeConverter()->convertType(op.getType());
    Type batchMeanOrVarTy =
        RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
    auto batchNormTrainingResult = rewriter.create(
        op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
        weight, bias, rewriter.getF32FloatAttr(eps),
        rewriter.getI64IntegerAttr(1));
    // Only the normalized output is used; batch mean/variance are dropped.
    rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
    return success();
  } else {
    Type outputTy = getTypeConverter()->convertType(op.getType());
    SmallVector castShape{inputTy.getShape().begin(),
                                   inputTy.getShape().end()};
    castShape[1] = weightTy.getShape()[0];
    auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
    // Feature counts must match among operands of mhlo::BatchNormInferenceOp.
    Value inputCasted =
        rewriter.create(op.getLoc(), castTy, input);
    Value output = rewriter.create(
        op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
        runningMean, runningVar,
        // 'epsilon' must satisfy constraint: 32-bit float attribute.
        rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp(op, outputTy, output);
    return success();
  }
}

// AtenNativeLayerNormOp
//
// Lowers aten.native_layer_norm for statically shaped float inputs by reusing
// batch normalization:
//   1. Flatten the input to [1, F, E]. F is the product of the leading
//      (non-normalized) dims; E is the product of `normalized_shape`.
//   2. Run a batch-norm-training op with feature index 1, scale = 1 and
//      offset = 0. Assuming the op reduces over every dim except the feature
//      index, each of the F rows is normalized over its E elements. It also
//      returns per-row mean and variance of shape [F].
//   3. Reshape the three results back to the converted result types.
//   4. Apply the affine transform `output * weight + bias`, broadcasting
//      weight and bias to the output type.
// Preconditions:
//   * `weight` and `bias` must not be None.
//   * `normalized_shape` and `eps` must be constants.
//   * The trailing input dims and the weight/bias shapes must all equal
//     `normalized_shape`.
//
// NOTE(review): the third result is the batch-norm variance. In PyTorch's
// native_layer_norm the third result is the reciprocal standard deviation
// (rstd = 1/sqrt(var + eps)). If the aten op here follows PyTorch, this value
// differs — confirm.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenNativeLayerNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.input();
  auto inputTy = input.getType().cast();
  auto inputShape = inputTy.getShape();
  auto inputRank = inputTy.getRank();
  Value weight = adaptor.weight();
  Value bias = adaptor.bias();

  if (!inputTy.hasStaticShape()) {
    return op->emitError("dynamic shaped input is not supported");
  }

  SmallVector normalizedShape;
  if (!matchPattern(op.normalized_shape(),
                    m_TorchListOfConstantInts(normalizedShape))) {
    return rewriter.notifyMatchFailure(
        op, "normalized_shape must be a list of const int");
  }
  double eps = 0;
  if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) {
    return rewriter.notifyMatchFailure(op,
                                       "non const float eps is unsupported");
  }
  if (failed(checkNotNone(rewriter, op, weight)) ||
      failed(checkNotNone(rewriter, op, bias))) {
    return op->emitError("none weight or bias is unsupported");
  }
  auto weightTy = weight.getType().cast();
  auto biasTy = bias.getType().cast();

  if (!inputTy.getElementType().isa() ||
      !biasTy.getElementType().isa() ||
      !weightTy.getElementType().isa()) {
    return op->emitError("currently only float data type are supported");
  }
  int64_t normalizedShapeRank = normalizedShape.size();
  // NOTE(review): the two string literals below concatenate to
  // "...bias shape ornormalized shape..." (missing space).
  if (weightTy.getRank() != normalizedShapeRank ||
      biasTy.getRank() != normalizedShapeRank ||
      inputRank < normalizedShapeRank || normalizedShapeRank < 1) {
    return rewriter.notifyMatchFailure(op, "input or weight or bias shape or"
                                           "normalized shape not compatible");
  }
  // The trailing input dims, weight and bias must all match normalized_shape.
  for (int64_t i = 1; i <= normalizedShapeRank; i++) {
    if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] ||
        weightTy.getShape()[normalizedShapeRank - i] !=
            normalizedShape[normalizedShapeRank - i] ||
        biasTy.getShape()[normalizedShapeRank - i] !=
            normalizedShape[normalizedShapeRank - i]) {
      return op.emitError("mismatching contracting dimension");
    }
  }

  // Flatten dims to fit batch_norm operation.
  int64_t numFeatureDimSize = 1;
  int64_t numEmbeddingDimSize = 1;
  for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) {
    numFeatureDimSize *= inputShape[i];
  }
  for (int64_t i = 0; i < normalizedShapeRank; i++) {
    numEmbeddingDimSize *= normalizedShape[i];
  }
  SmallVector inputFlattenShape{1, numFeatureDimSize,
                                         numEmbeddingDimSize};
  SmallVector meanOrVarMhloOutShape{numFeatureDimSize};

  auto mhloBatchNormOutTy =
      RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
  auto mhloBathNormOutMeanOrVarTy =
      RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());

  // Reshape input to [1, F, E].
  auto mhloInput = rewriter.create(
      op->getLoc(), mhloBatchNormOutTy, input,
      mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
                           {static_cast(inputFlattenShape.size())})
          .value());

  // Generate the "scale" (all ones) and "offset" (all zeros) values for
  // mhlo.BatchNormTrainingOp, so that it only normalizes; the real affine
  // transform is applied after reshaping back.
  SmallVector zeroConstVec(
      numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
                                              .cast()
                                              .getFloatSemantics()));
  SmallVector oneConstVec(
      numFeatureDimSize,
      APFloat(
          inputTy.getElementType().cast().getFloatSemantics(),
          1));
  auto oneOrZeroConstType =
      RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());

  Value scale = rewriter.create(
      op->getLoc(), oneOrZeroConstType,
      DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
  Value offset = rewriter.create(
      op->getLoc(), oneOrZeroConstType,
      DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
  auto batchNormTrainingResult = rewriter.create(
      op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
      mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
      rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));

  // Reshape back to the converted result types.
  auto outputTy =
      getTypeConverter()->convertType(op.getType(0)).cast();
  auto outputMeanOrVarTy =
      getTypeConverter()->convertType(op.getType(1)).cast();

  auto output = rewriter.create(
      op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
      mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
                           {static_cast(outputTy.getShape().size())})
          .value());
  auto mean = rewriter.create(
      op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
      mhlo::getConstTensor(
          rewriter, op, outputMeanOrVarTy.getShape(),
          {static_cast(outputMeanOrVarTy.getShape().size())})
          .value());
  auto var = rewriter.create(
      op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
      mhlo::getConstTensor(
          rewriter, op, outputMeanOrVarTy.getShape(),
          {static_cast(outputMeanOrVarTy.getShape().size())})
          .value());

  // Apply affine transform: output x weight + bias [element-wise]
  auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
  auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
  auto outputMulWeight =
      rewriter.create(op->getLoc(), output, bcastedWeight);
  auto finalOuput =
      rewriter.create(op->getLoc(), outputMulWeight, bcastedBias);
  rewriter.replaceOp(op, {finalOuput, mean, var});
  return success();
}

// AtenCatOp
//
// Lowers aten.cat. `dim` must be a constant, and `tensors` must come from a
// list-construct op. Every input is type-converted, promoted to the converted
// result type, and concatenated along `dim` (normalized against the result
// rank).
//
// NOTE(review): the normalized dim is not checked with isValidDim.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenCatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto outType =
      getTypeConverter()->convertType(op.getType()).cast();
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
    return rewriter.notifyMatchFailure(op,
                                       "only constant dim param is supported");
  }

  SmallVector torchTensors;
  if (!getListConstructElements(op.tensors(), torchTensors)) {
    return rewriter.notifyMatchFailure(
        op, "input should comes from a PrimListConstructOp");
  }
  SmallVector builtinTensors = getTypeConvertedValues(
      rewriter, op->getLoc(), getTypeConverter(), torchTensors);

  // Promote every input to the result type.
  for (auto &v : builtinTensors) {
    v = mhlo::promoteType(rewriter, v, outType);
  }

  size_t posDim = toPositiveDim(dim, outType.getRank());
  rewriter.replaceOpWithNewOp(
      op, outType, ValueRange(builtinTensors), posDim);
  return success();
}

// AtenNumelOp
//
// Computes numel as the product of all dimension sizes. The product is an
// integer of `options.dimSizeIndexBits` bits and starts from the constant 1.
// It is then converted to the converted result type if the types differ.
//
// NOTE(review): the dyn_cast result `selfTy` is dereferenced without a null
// check.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenNumelOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.self();
  auto selfTy = self.getType().dyn_cast();
  size_t rank = selfTy.getRank();

  Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
  auto loc = op->getLoc();
  Value numel = rewriter.create(
      loc, rewriter.getIntegerAttr(intType, 1));
  for (size_t d = 0; d < rank; ++d) {
    Value dimSize = rewriter.create(
        loc, intType, rewriter.create(loc, self, d));
    numel = rewriter.create(loc, numel, dimSize);
  }

  auto outTy = getTypeConverter()->convertType(op.getType());
  if (outTy != numel.getType()) {
    rewriter.replaceOpWithNewOp(op, outTy, numel);
  } else {
    rewriter.replaceOp(op, numel);
  }
  return success();
}

// AtenClampOp
//
// Lowers aten.clamp to a single clamp-style op taking (min, input, max). The
// op type is not visible in this copy.
//   * Scalar `min`/`max` are materialized as tensors of the input element
//     type.
//   * A None bound is replaced by the dtype's lowest/highest value from
//     getMinValueOfDtype/getMaxValueOfDtype (±infinity for floats).
//   * If both bounds are None, the pattern fails to match.
template <>
LogicalResult ConvertAtenOp::matchAndRewrite(
    AtenClampOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputType = input.getType().cast();
  auto inputElemType = inputType.getElementType();
  Value minValue = adaptor.min();
  Value maxValue = adaptor.max();
  if (failed(checkNotNone(rewriter, op, minValue)) &&
      failed(checkNotNone(rewriter, op, maxValue))) {
    return rewriter.notifyMatchFailure(
        op, "this op should be folded as its `min` and `max` both are none");
  } else if (failed(checkNotNone(rewriter, op, minValue))) {
    // Only `max` given: use the dtype's lowest value as the lower bound.
    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
    auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
    if (failed(minInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to generate min value of dtype");
    }
    minValue = *minInfo;
  } else if (failed(checkNotNone(rewriter, op, maxValue))) {
    // Only `min` given: use the dtype's highest value as the upper bound.
    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
    auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
    if (failed(maxInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to generate max value of dtype");
    }
    maxValue = *maxInfo;
  } else {
    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
  }
  rewriter.replaceOpWithNewOp(op, minValue, input, maxValue);
  return success();
}

// AtenArangeStartStepOp
// aten.arange.start_step = range(ceil((end-start)/step)) * step + start.
template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); // Get element type of resultType as dtype auto outType = this->getTypeConverter() ->convertType(op.getType()) .cast(); auto dtype = outType.getElementType(); if (!dtype.isa() && !dtype.isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: only int or float dtype supported"); } Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.start(), dtype); Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.end(), dtype); Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.step(), dtype); // Get length of the 1-d output tensor Value subOut = rewriter.create(loc, end, start); Value divOut = rewriter.create(loc, subOut, step); Value resultLength = rewriter.create( loc, RankedTensorType::get({1}, dtype), divOut); if (dtype.isa()) { resultLength = rewriter.create(loc, resultLength); resultLength = rewriter.create( loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); } Value window = rewriter.create(loc, outType, resultLength, 0); DenseIntElementsAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); rewriter.replaceOpWithNewOp(op, mulOut, start, broadcastDimensions); return success(); } void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); 
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp); #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp); #undef INSERT_BINARY_MULDIV_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp); 
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp); #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenClampOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); INSERT_ATENOP_PATTERN(AtenNumelOp); INSERT_ATENOP_PATTERN(AtenSizeIntOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp); #undef INSERT_BINARY_BROADCAST_PATTERN }