//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

#include <cmath>       // M_2_PI
#include <numeric>     // std::iota
#include <type_traits> // std::is_same

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
                             mlir::Value &self, mlir::Value &other,
                             size_t dimSizeIndexBits) {
  auto selfTy = dyn_cast<RankedTensorType>(self.getType());
  auto otherTy = dyn_cast<RankedTensorType>(other.getType());
  auto selfRank = selfTy.getRank();
  auto otherRank = otherTy.getRank();
  if (selfRank == 0 || otherRank == 0)
    return success();
  if (selfRank > otherRank) {
    auto unsqueezeDims =
        llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
    auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
                                              unsqueezeDims, dimSizeIndexBits);
    if (failed(unsqueezeInfo))
      return failure();
    other = *unsqueezeInfo;
  } else if (otherRank > selfRank) {
    auto unsqueezeDims =
        llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
    auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self,
                                              unsqueezeDims, dimSizeIndexBits);
    if (failed(unsqueezeInfo))
      return failure();
    self = *unsqueezeInfo;
  }
  return success();
}

bool skipMultiplyAlpha(Value alphaValue) {
  double doubleValue;
  auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue));
  int64_t intValue;
  auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue));
  return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1));
}

static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (isa<mlir::FloatType>(elementType)) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
                        /*negative=*/false));
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (isa<mlir::IntegerType>(elementType)) {
    auto integerType = cast<mlir::IntegerType>(elementType);
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMaxValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMaxValue(integerType.getWidth()));
    }
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}

static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (isa<mlir::FloatType>(elementType)) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
                        /*negative=*/true));
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (isa<mlir::IntegerType>(elementType)) {
    auto integerType = cast<mlir::IntegerType>(elementType);
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMinValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMinValue(integerType.getWidth()));
    }
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}

// These legalizations are for unary ops.
namespace {
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto selfType = cast<TensorType>(self.getType());
    if (!selfType) {
      return op.emitError("only Tensor types supported in StableHLO");
    }
    auto outType = cast<TensorType>(
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()));
    self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
    rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
    return success();
  }
};
} // namespace

// These legalizations are for unary ops that only support floating-point
// datatypes. There is no supported quantized integer mode for these.
namespace {
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto selfTy = cast<TensorType>(self.getType());
    if (!selfTy)
      return op.emitError("only Tensor types supported in StableHLO");
    if (isa<mlir::FloatType>(selfTy.getElementType())) {
      rewriter.replaceOpWithNewOp<StablehloOpT>(
          op,
          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
              op.getType()),
          self);
      return success();
    } else {
      return op.emitError(
          "only floating-point datatype legalization supported");
    }
  }
};
} // namespace

// These legalizations are for unary ops that promote the input to a
// floating-point datatype.
namespace { template class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only Tensor types supported in StableHLO"); auto resultTy = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); if (isa(resultTy.getElementType())) { Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); rewriter.replaceOpWithNewOp(op, resultTy, src); return success(); } else { return op.emitError( "only result to be floating-point datatype legalization supported"); } } }; } // namespace // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { template class ConvertAtenConstPatternOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto outType = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); if (!outType) return op.emitError("only Tensor types supported in StableHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) return op.emitError( "only floating-point or integer datatype legalization supported"); SmallVector shape; if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) { return op.emitError("shape must be a list of Scalar constants"); } int64_t size = 1; for (auto s : shape) size *= s; SmallVector values(size, fillVal); auto constOp = hlo::getConstTensor(rewriter, op, values, shape).value(); rewriter.replaceOpWithNewOp(op, outType, constOp); return success(); } }; } // namespace namespace { // Casts a tensor of exactly one element to an elemental type. 
// Much of this code is borrowed from
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`.
template <typename AtenOpT>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
    if (!inputType)
      return op.emitError("only Tensor types supported in StableHLO");
    Location loc = op.getLoc();
    Value input = adaptor.getA();
    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
    int64_t inputRank = inputSizes.size();
    Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();

    // The input must hold exactly one element: check that every dimension has
    // size 1.
    Value constantOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
    for (int64_t i = 0; i < inputRank; i++)
      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

    // Handle unsigned integer element types by converting to a signless
    // integer of the same width before extraction.
    if (inputType.getElementType().isUnsignedInteger()) {
      input = rewriter.create<stablehlo::ConvertOp>(
          loc, input,
          rewriter.getIntegerType(
              inputType.getElementType().getIntOrFloatBitWidth()));
    }

    Value constantZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    SmallVector<Value> indices(inputRank, constantZero);
    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    rewriter.replaceOp(
        op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype,
                                 /*srcOriginalDtype=*/
                                 inputType.getElementType()));
    return success();
  }
};
} // namespace

// The binary broadcast patterns
namespace {
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.getSelf();
    auto lhsTy = cast<TensorType>(lhs.getType());
    Value rhs = adaptor.getOther();
    auto rhsTy = cast<TensorType>(rhs.getType());
    if (!lhsTy || !rhsTy)
      return op.emitError("only Tensor types supported");
    auto outTy = cast<TensorType>(
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()));
    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
    rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
                                         /*broadcast_attr*/ nullptr);
    return success();
  }
};
} // namespace

// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
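// A rough sketch of the lowering below (assuming the CHLO broadcast ops bound
// to the pattern's template arguments):
//   aten.add(self, other, alpha)  ~>  broadcast_add(self, broadcast_mul(other, alpha))
//   aten.sub(self, other, alpha)  ~>  broadcast_sub(self, broadcast_mul(other, alpha))
// The multiply is skipped entirely when alpha == 1 (see skipMultiplyAlpha).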
namespace { template class ConvertAtenAddSubOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); RankedTensorType lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); RankedTensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); TensorType outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); if (isa(op)) { std::swap(lhs, rhs); } } lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); DenseI64ArrayAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); } DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); } }; } // namespace // Binary op legalizations for Mul/Div variants. namespace { template class ConvertAtenMulDivOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); TensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); auto outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if constexpr (std::is_same()) { rhs = lhs; } else { if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); } } DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); if constexpr (!std::is_same() && !std::is_same()) { rewriter.replaceOp(op, result); return success(); } auto tensorOp = dyn_cast(op.getOperation()); auto opRoundingMode = tensorOp ? tensorOp.getRoundingMode() : cast(op.getOperation()).getRoundingMode(); std::string roundingMode; if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) { return rewriter.notifyMatchFailure( op, "only support constant str rounding mode"); } // if trunc and int, do nothing if (roundingMode == "trunc" && isa(outElemTy)) { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. 
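// For a floating-point result, truncation is emulated below as
//   trunc(x) = sign(x) * floor(|x|)
// (an integer division already truncates, so no extra work is needed there).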
auto sign = rewriter.create(loc, result); auto abs = rewriter.create(loc, result); auto floor = rewriter.create(loc, abs); result = rewriter.create(loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) if (isa(outElemTy)) result = rewriter.create(loc, result).getResult(); else if (!outElemTy.isUnsignedInteger()) { TensorType defaultIntToFloatType = outType.cloneWith(outType.getShape(), rewriter.getF64Type()); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType); result = rewriter.create(loc, defaultIntToFloatType, lhs, rhs, bcastDimensions); result = rewriter.create(loc, result).getResult(); result = hlo::promoteType(rewriter, op.getLoc(), result, outType); } } rewriter.replaceOp(op, result); return success(); } }; } // namespace // Binary op legalizations for comparator ops. namespace { template class ConvertAtenCompareOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); RankedTensorType lhsTy = dyn_cast(lhs.getType()); RankedTensorType rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) { return op.emitError("only Tensor types supported in StableHLO"); } if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); // use lhs's element type as compute type rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); rhsTy = dyn_cast(rhs.getType()); } auto outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Type lhsElemTy = lhsTy.getElementType(); Type rhsElemTy = rhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat() || !rhsElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (isa(lhsElemTy) && isa(rhsElemTy)) { lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy); } else if (isa(lhsElemTy) && isa(rhsElemTy)) { rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); } else { if (lhsElemTy.getIntOrFloatBitWidth() > rhsElemTy.getIntOrFloatBitWidth()) { rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); } else { lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy); } } lhsElemTy = dyn_cast(lhs.getType()).getElementType(); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; if (isa(lhsElemTy)) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::FLOAT); } else if (isa(lhsElemTy)) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::SIGNED); } if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LT); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::GT); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::GE); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), 
chlo::ComparisonDirection::EQ); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::NE); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LT); } else if constexpr (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LE); } else { return op.emitError("operator haven't been supported"); } DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); return success(); } }; } // namespace // Binary op legalizations for Logical And/Or/Xor. namespace { template class ConvertAtenLogicalBinaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); RankedTensorType lhsTy = dyn_cast(lhs.getType()); RankedTensorType rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) return op.emitError("lhs must be a ranked tensor type"); TensorType outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Type outElemTy = outType.getElementType(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); } }; } // namespace // AtenTransposeIntOp namespace { class ConvertAtenTransposeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); int64_t dim0; if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) { return rewriter.notifyMatchFailure(op, "dim0 must be constant"); } int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) { return rewriter.notifyMatchFailure(op, "dim1 must be constant"); } auto inType = cast(self.getType()); auto inputRank = inType.getRank(); auto outType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) { return rewriter.notifyMatchFailure(op, "dim0 out of range"); } dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) { return rewriter.notifyMatchFailure(op, "dim1 out of range"); } SmallVector permValues(inputRank); std::iota(std::begin(permValues), std::end(permValues), 0); std::swap(permValues[dim0], permValues[dim1]); rewriter.replaceOpWithNewOp(op, outType, self, permValues); return success(); } }; } // namespace // AtenToDtypeOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenToDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = cast(getTypeConverter()->convertType(op.getType())); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSizeIntOp op, OpAdaptor adaptor, 
ConversionPatternRewriter &rewriter) const { // Not a tensor type. auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return op.emitError("only tensor types are currently supported"); Value dim; int64_t dimInt; if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, selfType.getRank()); if (!isValidDim(dimInt, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); dim = rewriter.create(op.getLoc(), dimInt); } else { Value inputRank = rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); dim = rewriter.create(op.getLoc(), rewriter.getIndexType(), dim); } auto dimSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), adaptor.getSelf(), dim); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), dimSize); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenWhereSelfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value cond = adaptor.getCondition(); Value other = adaptor.getOther(); auto outType = cast(getTypeConverter()->convertType(op.getType())); // promote self and other types self = hlo::promoteType(rewriter, op.getLoc(), self, outType); other = hlo::promoteType(rewriter, op.getLoc(), other, outType); if (failed( broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) return op.emitError("failed broadcast self and condition ranks"); if (failed( broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits))) return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), ArrayRef{cond, self, other}); return success(); } // AtenBroadcastToOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto selfTy = cast(self.getType()); auto outType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); if (options.enableStaticShape && selfTy.hasStaticShape()) { Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp); return success(); } SmallVector shape; if (!(getListConstructElements(adaptor.getSize(), shape))) { return op->emitError("desired shape must be a list of scalar"); } SmallVector bcastShapeVec; int64_t totalRank = shape.size(); int64_t selfRank = selfTy.getRank(); int64_t leadingRank = totalRank - selfRank; for (int64_t i = 0; i < totalRank; ++i) { Value dValue = shape[i]; Value newD; int64_t dInt; if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) && dInt == -1) { newD = rewriter.create(op->getLoc(), self, i - leadingRank); } else { dValue = rewriter.create(op->getLoc(), dValue); newD = rewriter.create( op->getLoc(), rewriter.getIndexType(), dValue); } bcastShapeVec.push_back(newD); } if (options.dimSizeIndexBits == 32) { for (auto &dsize : bcastShapeVec) { auto dsizeI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), dsize); dsize = rewriter.create(op->getLoc(), rewriter.getI32Type(), dsizeI64); } } if (bcastShapeVec.size() == 0) { rewriter.replaceOpWithNewOp(op, outType, self); } else { Value bcastShapeTensor = rewriter.create( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); 
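// broadcast_dimensions for the dynamic broadcast: input dim i maps to output
// dim (leadingRank + i); the leading output dims are newly broadcast dims.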
rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getDenseI64ArrayAttr(dimensionNumbers)); } return success(); } // AtenPermuteOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPermuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); // Not a ranked tensor type auto inType = dyn_cast(self.getType()); auto outType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); if (!inType) return op.emitError("only ranked tensor types with static shapes are " "currently supported"); SmallVector permValues; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(permValues))) return rewriter.notifyMatchFailure( op, "only constant dimensions are currently supported"); int64_t inRank = inType.getRank(); for (auto &d : permValues) { d = toPositiveDim(d, inRank); if (!isValidDim(d, inRank)) return op.emitError("not all dims are valid"); } rewriter.replaceOpWithNewOp(op, outType, self, permValues); return success(); } // ValueTensorLiteralOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType resultType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. // TODO: what about unsigned integer? if (auto elements = dyn_cast(op.getValueAttr())) { Type builtinTensorElemTy = resultType.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); DenseElementsAttr valueAttr = elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); }); rewriter.replaceOpWithNewOp(op, resultType, valueAttr); return success(); } rewriter.replaceOpWithNewOp(op, resultType, adaptor.getValue()); return success(); } // AtenTensorIntOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTensorIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType resultType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value innerValue = adaptor.getT(); Value stablehloTensor = hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType); rewriter.replaceOp(op, stablehloTensor); return success(); } // AtenReciprocalOp // Reciprocal(x) = Div(1, x) template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = cast(input.getType()); auto outTy = cast(getTypeConverter()->convertType(op.getType())); if (!isa(inputTy.getElementType())) { return op.emitError("only floating-point datatype legalization supported " "for AtenReciprocalOp"); } Value oneTensor = hlo::getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } // AtenPowTensorScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getExponent(); TensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); auto outType = cast( 
OpConversionPattern::getTypeConverter() ->convertType(op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); rewriter.replaceOp(op, result); return success(); } // AtenPowScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getExponent(); auto rhsType = dyn_cast(rhs.getType()); if (!rhsType) return op.emitError("only Tensor types supported in StableHLO"); auto outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!lhsType) { lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); } DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); rewriter.replaceOp(op, result); return success(); } // PrimNumToTensorScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType outputType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); auto outputElemType = outputType.getElementType(); Value stablehloTensor = hlo::scalarToStablehloTensor( rewriter, op, adaptor.getA(), outputElemType); rewriter.replaceOp(op, stablehloTensor); return success(); } // AtenScalarImplicitOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenScalarImplicitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Type inputDtype = cast(op.getA().getType()).getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); auto result = rewriter.create(loc, adaptor.getA()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); return success(); } // AtenContiguousOp // Ref: TosaToTosa.cpp for implementation details template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenContiguousOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return op.emitError("only tensor types are currently supported"); // FIXME: memory_format is not handled. 
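// With value-semantic tensors the memory format has no observable effect, so
// aten.contiguous lowers to an identity: the operand is returned unchanged.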
rewriter.replaceOp(op, adaptor.getSelf()); return success(); } // AtenReluOp // Relu(x) = Max(0, x) template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsTy = cast(lhs.getType()); auto lhsElemTy = lhsTy.getElementType(); if (!isa(lhsElemTy)) { return op->emitError("only float tensor in relu op is supported"); } Value zeroTensor = hlo::getConstantLike(rewriter, op->getLoc(), 0, lhs); rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } // Convert a Aten::GELU to HLO // Gelu(x, "none") = x * 0.5 * (1 + erf(x/(sqrt(2)))) // Gelu(x, "tanh") = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } std::string approximate; if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) { return op.emitError("approximate must be constant string"); } if (approximate != "none" && approximate != "tanh") { return op.emitError("unsupported approximate: ") << approximate; } Value one = hlo::getConstantLike(rewriter, loc, 1.0, input); Value two = hlo::getConstantLike(rewriter, loc, 2.0, input); Value three = hlo::getConstantLike(rewriter, loc, 3.0, input); Value half = hlo::getConstantLike(rewriter, loc, 0.5, input); // 2/pi Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input); Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input); // x * 0.5 auto inputMulHalf = rewriter.create(loc, input, half); if (approximate == "none") { auto rsqrtTwo = rewriter.create(loc, two); auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); auto erfAdd = rewriter.create(loc, erf, one); rewriter.replaceOpWithNewOp(op, erfAdd, inputMulHalf); return success(); } else { auto sqrtTwoPi = rewriter.create(loc, twoDivPi); // x^3 auto powThree = rewriter.create(loc, input, three); // x + 0.044715 * x^3 auto add = rewriter.create( loc, input, rewriter.create(loc, t, powThree)); auto tanh = rewriter.create( loc, rewriter.create(loc, sqrtTwoPi, add)); auto tanhAdd = rewriter.create(loc, tanh, one); rewriter.replaceOpWithNewOp(op, tanhAdd, inputMulHalf); return success(); } } // AtenLog2Op template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog2Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } auto outTy = cast(getTypeConverter()->convertType(op.getType())); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input); auto log2Op = rewriter.create(op.getLoc(), two); auto logInputOp = rewriter.create(op.getLoc(), input); rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log2Op); return success(); } // AtenLog10Op template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog10Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } auto outTy = 
cast(getTypeConverter()->convertType(op.getType())); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input); auto log10Op = rewriter.create(op.getLoc(), ten); auto logInputOp = rewriter.create(op.getLoc(), input); rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log10Op); return success(); } // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputType = cast(input.getType()); if (!isa(inputType.getElementType())) { return rewriter.notifyMatchFailure(op, "only float tensor is supported"); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), input); return success(); } // AtenBatchNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); auto inputTy = cast(input.getType()); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); Value runningMean = adaptor.getRunningMean(); Value runningVar = adaptor.getRunningVar(); // momentum is ignored Value momentum = adaptor.getMomentum(); (void)momentum; // handle feature index, see torch's BatchNorm1d, BatchNorm2d, BatchNorm3d, // all of NC, NCL, NCHW, NCDHW's feature index is 1. int64_t feature_index = 1; if (!isa(inputTy.getElementType())) { return op.emitError("only input tensor of float type is supported"); } auto inputElemTy = cast(inputTy.getElementType()); Value channelDim = rewriter.create(op->getLoc(), input, feature_index); if (options.dimSizeIndexBits == 32) { auto channelDimI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), channelDim); channelDim = rewriter.create( op->getLoc(), rewriter.getI32Type(), channelDimI64); } Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { weight = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { bias = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { runningVar = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { runningMean = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } auto weightTy = cast(weight.getType()); auto biasTy = cast(bias.getType()); auto runningMeanTy = cast(runningMean.getType()); auto runningVarTy = cast(runningVar.getType()); if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } if (!isa(weightTy.getElementType()) || !isa(biasTy.getElementType()) || !isa(runningMeanTy.getElementType()) || !isa(runningVarTy.getElementType())) { return op.emitError("only float 
weight/bias/runningMean/runningVar tensor " "of float type is supported"); } double eps = 0.0; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported"); } bool training = false; if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { return rewriter.notifyMatchFailure(op, "non-bool training unsupported"); } // TODO: handle cudnnEnabled parameter. Here, we just ignore it! bool cudnnEnabled = false; if (!matchPattern(op.getCudnnEnabled(), m_TorchConstantBool(&cudnnEnabled))) { return rewriter.notifyMatchFailure(op, "non-bool cudnn_enabled unsupported"); } if (training) { Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; if (cast(weightTy.getElementType()).getWidth() > cast(inputTy.getElementType()).getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType); weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); auto batchNormTrainingResult = rewriter.create( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), batchNormTrainingResult.getResult(0), cast(outputTy)); } else { auto batchNormTrainingResult = rewriter.create( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = batchNormTrainingResult.getResult(0); } rewriter.replaceOp(op, output); return success(); } else { Type outputTy = getTypeConverter()->convertType(op.getType()); SmallVector castShape{inputTy.getShape().begin(), inputTy.getShape().end()}; castShape[1] = weightTy.getShape()[0]; auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); // Feature counts must match among operands of // stablehlo::BatchNormInferenceOp. Value inputCasted = rewriter.create(op.getLoc(), castTy, input); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. 
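// When the input and weight element types differ, promote all operands to the
// wider float type, run batch-norm inference in that type, and convert the
// result back to the expected output type afterwards.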
if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; if (cast(weightTy.getElementType()).getWidth() > cast(inputTy.getElementType()).getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType); weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); runningMean = hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType); runningVar = hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType); Value bnResult = rewriter.create( op.getLoc(), convertedType, input, weight, bias, runningMean, runningVar, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), bnResult, cast(outputTy)); } else { output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); } rewriter.replaceOpWithNewOp(op, outputTy, output); return success(); } } // AtenNativeLayerNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNativeLayerNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); auto inputTy = cast(input.getType()); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); if (!inputTy.hasStaticShape()) { return op->emitError("dynamic shaped input is not supported"); } SmallVector normalizedShape; if (!matchPattern(op.getNormalizedShape(), m_TorchListOfConstantInts(normalizedShape))) { return rewriter.notifyMatchFailure( op, "normalized_shape must be a list of const int"); } double eps = 0; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non const float eps is unsupported"); } if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias))) { return op->emitError("none weight or bias is unsupported"); } auto weightTy = cast(weight.getType()); auto biasTy = cast(bias.getType()); if (!isa(inputTy.getElementType()) || !isa(biasTy.getElementType()) || !isa(weightTy.getElementType())) { return op->emitError("currently only float data type are supported"); } int64_t normalizedShapeRank = normalizedShape.size(); if (weightTy.getRank() != normalizedShapeRank || biasTy.getRank() != normalizedShapeRank || inputRank < normalizedShapeRank || normalizedShapeRank < 1) { return rewriter.notifyMatchFailure(op, "input or weight or bias shape or" "normalized shape not compatible"); } for (int64_t i = 1; i <= normalizedShapeRank; i++) { if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] || weightTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i] || biasTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i]) { return op.emitError("mismatching contracting dimension"); } } // Flatten dims to fit batch_norm operation. 
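// Collapse the input to [1, numFeatureDimSize, numEmbeddingDimSize] so that a
// batch-norm-training op with feature_index=1 normalizes over each
// normalized_shape slice; the results are reshaped back afterwards.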
int64_t numFeatureDimSize = 1; int64_t numEmbeddingDimSize = 1; for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { numFeatureDimSize *= inputShape[i]; } for (int64_t i = 0; i < normalizedShapeRank; i++) { numEmbeddingDimSize *= normalizedShape[i]; } SmallVector inputFlattenShape{1, numFeatureDimSize, numEmbeddingDimSize}; SmallVector meanOrVarStablehloOutShape{numFeatureDimSize}; auto stablehloBatchNormOutTy = RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get( meanOrVarStablehloOutShape, inputTy.getElementType()); // Reshape input auto stablehloInput = rewriter.create( op->getLoc(), stablehloBatchNormOutTy, input, hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .value()); // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero( cast(inputTy.getElementType()).getFloatSemantics())); SmallVector oneConstVec( numFeatureDimSize, APFloat( cast(inputTy.getElementType()).getFloatSemantics(), 1)); auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); Value scale = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); Value offset = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); auto batchNormTrainingResult = rewriter.create( op->getLoc(), stablehloBatchNormOutTy, stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = cast(getTypeConverter()->convertType(op.getType(0))); auto outputMeanOrVarTy = cast(getTypeConverter()->convertType(op.getType(1))); auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), hlo::getConstTensor(rewriter, op, outputTy.getShape(), {static_cast(outputTy.getShape().size())}) .value()); auto mean = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); auto var = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); // Apply affine transform: output x weight + bias [element-wise] auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); auto outputMulWeight = rewriter.create(op->getLoc(), output, bcastedWeight); auto finalOuput = rewriter.create( op->getLoc(), outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } // AtenCatOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = cast(getTypeConverter()->convertType(op.getType())); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "only constant dim param is supported"); } dim = toPositiveDim(dim, outType.getRank()); if (!isValidDim(dim, outType.getRank())) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector torchTensors; if 
(!getListConstructElements(op.getTensors(), torchTensors)) { return rewriter.notifyMatchFailure( op, "input should comes from a PrimListConstructOp"); } SmallVector builtinTensors = getTypeConvertedValues( rewriter, op->getLoc(), getTypeConverter(), torchTensors); // Promote type for (auto &v : builtinTensors) { v = hlo::promoteType(rewriter, op->getLoc(), v, outType); } rewriter.replaceOpWithNewOp( op, outType, ValueRange(builtinTensors), dim); return success(); } // AtenNumelOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNumelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = dyn_cast(self.getType()); size_t rank = selfTy.getRank(); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto loc = op->getLoc(); Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); for (size_t d = 0; d < rank; ++d) { Value dimSize = rewriter.create( loc, intType, rewriter.create(loc, self, d)); numel = rewriter.create(loc, numel, dimSize); } auto outTy = getTypeConverter()->convertType(op.getType()); if (outTy != numel.getType()) { rewriter.replaceOpWithNewOp(op, outTy, numel); } else { rewriter.replaceOp(op, numel); } return success(); } // AtenClampOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputType = cast(input.getType()); auto inputElemType = inputType.getElementType(); Value minValue = adaptor.getMin(); Value maxValue = adaptor.getMax(); if (failed(checkNotNone(rewriter, op, minValue)) && failed(checkNotNone(rewriter, op, maxValue))) { return rewriter.notifyMatchFailure( op, "this op should be folded as its `min` and `max` both are none"); } else if (failed(checkNotNone(rewriter, op, minValue))) { maxValue = hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); if (failed(minInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate min value of dtype"); } minValue = *minInfo; } else if (failed(checkNotNone(rewriter, op, maxValue))) { minValue = hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); if (failed(maxInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate max value of dtype"); } maxValue = *maxInfo; } else { minValue = hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); maxValue = hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); } rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); return success(); } // AtenClampTensorOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputType = cast(input.getType()); auto inputElemType = inputType.getElementType(); Value minValue = adaptor.getMin(); Value maxValue = adaptor.getMax(); auto minIsNotNone = checkNotNone(rewriter, op, minValue); auto maxIsNotNone = checkNotNone(rewriter, op, maxValue); if (failed(minIsNotNone) && failed(maxIsNotNone)) { return rewriter.notifyMatchFailure( op, "this op should be folded as its `min` and `max` both are none"); } else if (failed(minIsNotNone)) { auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); if (failed(minInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate min value of dtype"); 
    }
    minValue = *minInfo;
  } else if (failed(maxIsNotNone)) {
    auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
    if (failed(maxInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to generate max value of dtype");
    }
    maxValue = *maxInfo;
  }
  if (inputType.hasStaticShape()) {
    minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
    maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
  }
  rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
                                                  maxValue);
  return success();
}

// AtenArangeStartStepOp
// aten.arange.start_step = range(ceil((end-start)/step)) * step + start.
template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
    AtenArangeStartStepOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();

  // Get element type of resultType as dtype
  auto outType = cast<RankedTensorType>(
      this->getTypeConverter()->convertType(op.getType()));
  auto dtype = outType.getElementType();
  if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: only int or float dtype supported");
  }

  Value start =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
  Value end =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
  Value step =
      hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);

  // Get length of the 1-d output tensor
  Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
  // promote div to f64
  Type divType = RankedTensorType::get({}, rewriter.getF64Type());
  Value divOut = rewriter.create<stablehlo::DivOp>(
      loc, rewriter.create<stablehlo::ConvertOp>(loc, divType, subOut),
      rewriter.create<stablehlo::ConvertOp>(loc, divType, step));
  // ceil to i64
  Value resultLength = rewriter.create<stablehlo::ConvertOp>(
      loc, RankedTensorType::get({}, rewriter.getI64Type()),
      rewriter.create<stablehlo::CeilOp>(loc, divOut));
  resultLength = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);

  Value window =
      rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
  DenseI64ArrayAttr broadcastDimensions;
  Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
                                                       broadcastDimensions);
  rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
                                                    broadcastDimensions);
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
    AtenConstantPadNdOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  auto selfTy = cast<RankedTensorType>(self.getType());
  auto selfElemTy = selfTy.getElementType();
  int64_t rank = selfTy.getRank();

  SmallVector<int64_t> padInts;
  if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
    return rewriter.notifyMatchFailure(op,
                                       "only support constant int pad ranges");
  uint64_t padRank = padInts.size() / 2;
  if (padRank * 2 != padInts.size())
    return rewriter.notifyMatchFailure(op, "pad range size is not even");
  if (rank < 0 || padRank > (uint64_t)rank)
    return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

  // Initialize low/high paddings with 0 for all the dims.
  SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
  SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
  // Add the requested padding - note op.pad() is highest dim first ordered
  // pairs of low,high.
for (uint64_t i = 0; i < padRank; ++i) { lowPadding[rank - i - 1] = padInts[i * 2]; highPadding[rank - i - 1] = padInts[i * 2 + 1]; } Value constantValue = hlo::scalarToStablehloTensor( rewriter, op, adaptor.getValue(), selfElemTy); SmallVector interiorPadding(rank, 0); rewriter.replaceOpWithNewOp( op, self, constantValue, lowPadding, highPadding, interiorPadding); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluBackwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto outType = cast(this->getTypeConverter()->convertType(op.getType())); if (!outType) { return op.emitError("only tensor type is supported"); } // TODO: Handle approximate. std::string approximate; if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) || approximate != "none") { return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } // Create constant value Value kAlpha = hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input); Value cstAlpha0 = hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input); Value half = hlo::getConstantLike(rewriter, loc, .5, input); Value one = hlo::getConstantLike(rewriter, loc, 1.0, input); Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input); // Compute Value kBeta0 = rewriter.create(loc, outType, kAlpha, cstAlpha0); Value kBeta = rewriter.create(loc, outType, kBeta0, half); Value erfArg = rewriter.create(loc, outType, kAlpha, adaptor.getSelf()); Value erf = rewriter.create(loc, outType, erfArg); Value erfAdd = rewriter.create(loc, outType, erf, one); Value cdf = rewriter.create(loc, outType, erfAdd, half); Value inputSquared = rewriter.create( loc, outType, adaptor.getSelf(), adaptor.getSelf()); Value negHalfInputSquared = rewriter.create(loc, outType, inputSquared, negHalf); Value expRes = rewriter.create(loc, outType, negHalfInputSquared); Value pdf = rewriter.create(loc, outType, kBeta, expRes); Value pdfTimesInput = rewriter.create(loc, outType, pdf, adaptor.getSelf()); Value pdfTimesInputAddCdf = rewriter.create(loc, outType, pdfTimesInput, cdf); rewriter.replaceOpWithNewOp( op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsTy = cast(lhs.getType()); Value rhs = adaptor.getExponent(); auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); auto outTy = cast(this->getTypeConverter()->convertType(op.getType())); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); return success(); } // Converts `aten.empty.memory_format` to `tensor.empty` op. template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmptyMemoryFormatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // TODO: Add support pin_memory and memory_format features. // At this point all tensors should have value semantics, and hence the // `layout` check can be ignored. // The pin_memory should be either `False` or `none`. 
  bool pinMemory;
  if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
      (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
       pinMemory))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: pin_memory must be either None or false");

  // Only `none`, `contiguous` and `preserve` memory_format is supported.
  if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
    int64_t memoryFormat;
    if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the memory format should be specified in "
              "an integer constant");
    if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
        memoryFormat != torch_upstream::MemoryFormat::Preserve)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only none, contiguous and preserve "
              "memory_format is supported");
  }

  if (!isa<Torch::NoneType>(op.getDevice().getType())) {
    std::string device;
    if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: device must be a constant str");
  }

  // TODO: Add support for non-strided layout.
  // torch.layout is by default strided i.e. 0.
  if (!isa<Torch::NoneType>(op.getLayout().getType())) {
    int64_t tensorLayout;
    if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: layout must be a constant");
    else if (tensorLayout != torch_upstream::Layout::Strided)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: layout is expected to be strided");
  }

  Location loc = op.getLoc();
  const TypeConverter *typeConverter = this->getTypeConverter();
  SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
  if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: size must be constructed using ListConstruct");
  }
  resultSize =
      getTypeConvertedValues(rewriter, loc, typeConverter, resultSizeTorchInt);
  for (auto size : resultSize)
    resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));

  auto resultType =
      cast<RankedTensorType>(typeConverter->convertType(op.getType()));
  Type resultElementType;
  if (isa<Torch::NoneType>(op.getDtype().getType())) {
    resultElementType = resultType.getElementType();
  } else {
    int64_t dtypeInt;
    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: dtype must be a constant integer or none");
    FailureOr<Type> maybeResultElementType =
        torch_to_stablehlo::getBackendTypeForScalarType(
            op->getContext(), (torch_upstream::ScalarType)dtypeInt);
    if (failed(maybeResultElementType)) {
      return rewriter.notifyMatchFailure(
          op, "unable to convert `dtypeInt` to builtin type");
    }
    resultElementType = *maybeResultElementType;
  }

  // Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create( loc, getAsOpFoldResult(resultSizeIndex), resultElementType); rewriter.replaceOpWithNewOp(op, resultType, initTensor); return success(); } // RuntimeAssertOp namespace { class ConvertRuntimeAssertOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool condition; if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) { return rewriter.notifyMatchFailure( op, "unimplemented: condition must be a constant"); } if (!condition) { return op->emitError("condition must be true"); } rewriter.eraseOp(op); return success(); } }; } // namespace // AtenFillScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFillScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = cast(getTypeConverter()->convertType(op.getType())); auto dtype = outType.getElementType(); Value scalarTensor = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); Value shapeTensor = rewriter.create(op->getLoc(), adaptor.getSelf()); Value bcastScalar = rewriter.create( op->getLoc(), outType, scalarTensor, shapeTensor, rewriter.getDenseI64ArrayAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); } // AtenFlipOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFlipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = cast(getTypeConverter()->convertType(op.getType())); SmallVector dims; if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) { return rewriter.notifyMatchFailure(op, "dims must be a list of const int"); } for (unsigned i = 0, e = dims.size(); i < e; i++) { dims[i] = toPositiveDim(dims[i], outType.getRank()); if (!isValidDim(dims[i], outType.getRank())) { return rewriter.notifyMatchFailure(op, "dim is statically invalid"); } } rewriter.replaceOpWithNewOp(op, outType, self, dims); return success(); } // AtenRemainderTensorOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenRemainderTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); auto resultType = cast(getTypeConverter()->convertType(op.getType())); lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); rewriter.replaceOpWithNewOp(op, lhs, rhs); return success(); } // AtenFmodTensorOp // torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFmodTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op->getLoc(); Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); auto resultType = cast(getTypeConverter()->convertType(op.getType())); lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); stablehlo::MulOp mul; auto div = rewriter.create(loc, lhs, rhs); if (isa(resultType.getElementType())) { // rounding mode is trunc auto sign = rewriter.create(loc, div); auto abs = rewriter.create(loc, div); auto floor = rewriter.create(loc, abs); auto trunc = rewriter.create(loc, sign, floor); mul = rewriter.create(loc, trunc, rhs); } else { mul = rewriter.create(loc, div, rhs); } rewriter.replaceOpWithNewOp(op, lhs, mul); return success(); } // 
// AtenBitwiseLeftShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
    AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  Value rhs = adaptor.getOther();

  auto resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
  rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
  return success();
}

// AtenBitwiseRightShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
    AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  Value rhs = adaptor.getOther();

  auto resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
  rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
    AtenTrilOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();

  Value self = adaptor.getSelf();
  auto selfTy = cast<RankedTensorType>(self.getType());
  if (!selfTy.hasStaticShape()) {
    return op->emitError("dynamic shaped input is not supported");
  }

  // Build a lower-triangular mask over the last two dims: keep element
  // (row, col) when col <= row + diagonal, broadcast the 2-D mask to the
  // result shape, then select between the input and zeros.
  ArrayRef<int64_t> selfShape = selfTy.getShape();
  int64_t selfRank = selfTy.getRank();
  auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
  auto iotaTy = RankedTensorType::get(
      {selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
  Value colIdxTensor =
      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
  Value rowIdxTensor =
      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();

  Value diagonal = adaptor.getDiagonal();
  Value diagonalTensor =
      rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();

  auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
  Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
      loc, rowIdxTensor, diagonalTensor, bcastDimensions);

  auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
      rewriter.getContext(), stablehlo::ComparisonDirection::LE);
  auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
      rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
  auto cmpTy = iotaTy.clone(rewriter.getI1Type());
  Value cmpRes = rewriter.create<stablehlo::CompareOp>(
      loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
      cmpTypeAttr);

  auto resTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  auto bcastTy = resTy.clone(rewriter.getI1Type());
  auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
  Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
      loc, bcastTy, cmpRes, bcastAttr);

  auto resElemTy = resTy.getElementType();
  Value zeroTensor;
  if (isa<mlir::FloatType>(resElemTy)) {
    auto constAttr = SplatElementsAttr::get(
        resTy,
        llvm::APFloat::getZero(
            cast<mlir::FloatType>(resElemTy).getFloatSemantics(), false));
    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
  } else if (isa<mlir::IntegerType>(resElemTy)) {
    auto constAttr = SplatElementsAttr::get(
        resTy,
        llvm::APInt::getZero(cast<mlir::IntegerType>(resElemTy).getWidth()));
    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
  } else {
    return op.emitError("element type is not float or integer");
  }

  rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
      op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
  return success();
}

void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();

  target.addIllegalOp<AtenTransposeIntOp>();
  patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
  target.addIllegalOp<RuntimeAssertOp>();
  patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
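  // The INSERT_* helper macros below all follow the same idiom: mark the
  // Torch op as illegal for the conversion target and register the pattern
  // that rewrites it. As a rough, hand-expanded sketch (for illustration
  // only), INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp) amounts to:
  //   target.addIllegalOp<AtenNegOp>();
  //   patterns.add<ConvertAtenUnaryOp<AtenNegOp, stablehlo::NegOp>>(
  //       typeConverter, context);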
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp)                              \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
  INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
  INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
  INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
  INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp);
  INSERT_UNARY_PATTERN(AtenExpm1Op, stablehlo::Expm1Op);
#undef INSERT_UNARY_PATTERN

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp)                       \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter,   \
                                                              context)
  INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
  INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
  INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp);
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp)                \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, StablehloOp>>(            \
      typeConverter, context)
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp);
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,      \
                                                           context)
  INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
  INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN

#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp)                                \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, context)
  INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
  INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
  INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
#undef INSERT_TENSOR_TO_SCALAR_PATTERN

#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp)                           \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
  INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
#undef INSERT_BINARY_ADDSUB_PATTERN

#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp)                           \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
  INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
  INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
#undef INSERT_BINARY_MULDIV_PATTERN

#define INSERT_BINARY_COMPARE_PATTERN(AtenOp)                                  \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)
  INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp);
  INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
#undef INSERT_BINARY_COMPARE_PATTERN

#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp)                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter,      \
                                                           context)
  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
  INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
#undef INSERT_BINARY_LOGICAL_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
  INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
  INSERT_ATENOP_PATTERN(AtenPermuteOp);
  INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
  INSERT_ATENOP_PATTERN(AtenTensorIntOp);
  INSERT_ATENOP_PATTERN(AtenReciprocalOp);
  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenPowScalarOp);
  INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
  INSERT_ATENOP_PATTERN(AtenContiguousOp);
  INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
  INSERT_ATENOP_PATTERN(AtenReluOp);
  INSERT_ATENOP_PATTERN(AtenGeluOp);
  INSERT_ATENOP_PATTERN(AtenLog2Op);
  INSERT_ATENOP_PATTERN(AtenLog10Op);
  INSERT_ATENOP_PATTERN(AtenErfOp);
  INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
  INSERT_ATENOP_PATTERN(AtenCatOp);
  INSERT_ATENOP_PATTERN(AtenClampOp);
  INSERT_ATENOP_PATTERN(AtenClampTensorOp);
  INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
  INSERT_ATENOP_PATTERN(AtenBatchNormOp);
  INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
  INSERT_ATENOP_PATTERN(AtenNumelOp);
  INSERT_ATENOP_PATTERN(AtenSizeIntOp);
  INSERT_ATENOP_PATTERN(AtenToDtypeOp);
  INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
  INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
  INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
  INSERT_ATENOP_PATTERN(AtenFillScalarOp);
  INSERT_ATENOP_PATTERN(AtenFlipOp);
  INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
  INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
  INSERT_ATENOP_PATTERN(AtenTrilOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp)                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>(             \
      typeConverter, context)
  INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
  INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp,
                                  chlo::BroadcastAndOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp,
                                  chlo::BroadcastXorOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op);
#undef INSERT_BINARY_BROADCAST_PATTERN
}