//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "utils/hlo_utils.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, size_t dimSizeIndexBits) { auto selfTy = self.getType().template dyn_cast(); auto otherTy = other.getType().template dyn_cast(); auto selfRank = selfTy.getRank(); auto otherRank = otherTy.getRank(); if (selfRank == 0 || otherRank == 0) return success(); if (selfRank > otherRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, unsqueezeDims, dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); other = *unsqueezeInfo; } else if (otherRank > selfRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); self = *unsqueezeInfo; } return success(); } bool skipMultiplyAlpha(Value alphaValue) { double doubleValue; auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue)); int64_t intValue; auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue)); return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0)); } static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementType); if (elementType.isa()) { auto constAttr = SplatElementsAttr::get( constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/false)); return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { auto integerType = elementType.cast(); DenseElementsAttr constAttr; if (integerType.isUnsigned()) { constAttr = SplatElementsAttr::get( constType, APInt::getMaxValue(integerType.getWidth())); } else { constAttr = SplatElementsAttr::get( constType, APInt::getSignedMaxValue(integerType.getWidth())); } return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); } static FailureOr getMinValueOfDtype(Operation *op, Type elementType, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementType); if (elementType.isa()) { auto constAttr = SplatElementsAttr::get( constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/true)); return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { auto integerType = elementType.cast(); DenseElementsAttr constAttr; if (integerType.isUnsigned()) { constAttr = SplatElementsAttr::get( constType, APInt::getMinValue(integerType.getWidth())); } else { constAttr = SplatElementsAttr::get( constType, APInt::getSignedMinValue(integerType.getWidth())); } return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); } // These legalizations are for unary ops. namespace { template class ConvertAtenUnaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); auto selfType = self.getType().cast(); if (!selfType) { return op.emitError("only Tensor types supported in StableHLO"); } auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); self = hlo::promoteType(rewriter, op.getLoc(), self, outType); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } }; } // namespace // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. namespace { template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only Tensor types supported in StableHLO"); if (selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), self); return success(); } else { return op.emitError( "only floating-point datatype legalization supported"); } } }; } // namespace // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { template class ConvertAtenConstPatternOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template dyn_cast(); if (!outType) return op.emitError("only Tensor types supported in StableHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) return op.emitError( "only floating-point or integer datatype legalization supported"); SmallVector shape; if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) { return op.emitError("shape must be a list of Scalar constants"); } int64_t size = 1; for (auto s : shape) size *= s; SmallVector values(size, fillVal); auto constOp = hlo::getConstTensor(rewriter, op, values, shape).value(); rewriter.replaceOpWithNewOp(op, outType, constOp); return success(); } }; } // namespace namespace { // Casts a tensor of exactly one element to an elemental type. // Many codes borrowed from // `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp` template class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto inputType = adaptor.getA().getType().template dyn_cast(); if (!inputType) op.emitError("only Tensor types supported in StableHLO"); Location loc = op.getLoc(); Value input = adaptor.getA(); SmallVector inputSizes = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputSizes.size(); Type inputDtype = op.getA().getType().template cast().getDtype(); Value constantOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); for (int64_t i = 0; i < inputRank; i++) checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne); Value constantZero = rewriter.create(loc, rewriter.getIndexAttr(0)); SmallVector indices(inputRank, constantZero); Value result = rewriter.create(loc, input, indices); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); return success(); } }; } // namespace // The binary broadcast patterns namespace { template class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.getOther(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); auto outTy = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); return success(); } }; } // namespace // These binary op legalizations are specific to add/sub which have an // alpha multiplier. namespace { template class ConvertAtenAddSubOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); RankedTensorType lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.getOther(); RankedTensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); if (isa(op)) { std::swap(lhs, rhs); } } lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); DenseIntElementsAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); } DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); } }; } // namespace // Binary op legalizations for Mul/Div variants. namespace { template class ConvertAtenMulDivOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); auto lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.getOther(); TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (std::is_same()) { rhs = lhs; } else if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); } DenseIntElementsAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); if (!isa(op)) { rewriter.replaceOp(op, result); return success(); } AtenDivTensorModeOp divTensorModeOp = llvm::dyn_cast(op.getOperation()); std::string roundingMode; if (!matchPattern(divTensorModeOp.getRoundingMode(), m_TorchConstantStr(roundingMode))) return rewriter.notifyMatchFailure( op, "only support constant str rounding mode"); if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. auto sign = rewriter.create(loc, result); auto abs = rewriter.create(loc, result); auto floor = rewriter.create(loc, abs); result = rewriter.create(loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) result = rewriter.create(loc, result).getResult(); } rewriter.replaceOp(op, result); return success(); } }; } // namespace // Binary op legalizations for comparator ops. namespace { template class ConvertAtenCompareOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); RankedTensorType lhsTy = lhs.getType().dyn_cast(); RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) return op.emitError("only Tensor types supported in StableHLO"); RankedTensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), lhsElemTy); } // TODO: what is the PyTorch default type promotion? rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; if (lhsElemTy.isa()) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::FLOAT); } else if (lhsElemTy.isa()) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::SIGNED); } if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LT); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::GT); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::GE); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::EQ); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::NE); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LT); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = chlo::ComparisonDirectionAttr::get( op->getContext(), chlo::ComparisonDirection::LE); } else { return op.emitError("operator haven't been supported"); } DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); return success(); } }; } // namespace // Binary op legalizations for Logical And/Or/Xor. namespace { template class ConvertAtenLogicalBinaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Value lhs = hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType); Value rhs = hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); } }; } // namespace // AtenTransposeIntOp namespace { class ConvertAtenTransposeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); int64_t dim0; if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) { return rewriter.notifyMatchFailure(op, "dim0 must be constant"); } int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) { return rewriter.notifyMatchFailure(op, "dim1 must be constant"); } auto inType = self.getType().cast(); auto inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) { return rewriter.notifyMatchFailure(op, "dim0 out of range"); } dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) { return rewriter.notifyMatchFailure(op, "dim1 out of range"); } SmallVector permValues(inputRank); std::iota(std::begin(permValues), std::end(permValues), 0); std::swap(permValues[dim0], permValues[dim1]); DenseIntElementsAttr permutation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); rewriter.replaceOpWithNewOp(op, outType, self, permutation); return success(); } }; } // namespace // AtenToDtypeOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenToDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = getTypeConverter()->convertType(op.getType()).cast(); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSizeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. auto selfType = adaptor.getSelf().getType().dyn_cast(); if (!selfType) return op.emitError("only tensor types are currently supported"); Value dim; int64_t dimInt; if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, selfType.getRank()); if (!isValidDim(dimInt, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); dim = rewriter.create(op.getLoc(), dimInt); } else { Value inputRank = rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); dim = rewriter.create(op.getLoc(), rewriter.getIndexType(), dim); } auto dimSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), adaptor.getSelf(), dim); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), dimSize); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenWhereSelfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value cond = adaptor.getCondition(); Value other = adaptor.getOther(); auto outType = getTypeConverter()->convertType(op.getType()).cast(); // promote self and other types self = hlo::promoteType(rewriter, op.getLoc(), self, outType); other = hlo::promoteType(rewriter, op.getLoc(), other, outType); if (failed( broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) return op.emitError("failed broadcast self and condition ranks"); if (failed( broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits))) return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), ArrayRef{cond, self, other}); return success(); } // AtenBroadcastToOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); if (options.enableStaticShape && selfTy.hasStaticShape()) { Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp); return success(); } SmallVector shape; if (!(getListConstructElements(adaptor.getSize(), shape))) { return op->emitError("desired shape must be a list of scalar"); } SmallVector bcastShapeVec; int64_t totalRank = shape.size(); int64_t selfRank = selfTy.getRank(); int64_t leadingRank = totalRank - selfRank; for (int64_t i = 0; i < totalRank; ++i) { Value dValue = shape[i]; Value newD; int64_t dInt; if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) && dInt == -1) { newD = rewriter.create(op->getLoc(), self, i - leadingRank); } else { dValue = rewriter.create(op->getLoc(), dValue); newD = rewriter.create( op->getLoc(), rewriter.getIndexType(), dValue); } bcastShapeVec.push_back(newD); } if (options.dimSizeIndexBits == 32) { for (auto &dsize : bcastShapeVec) { auto dsizeI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), dsize); dsize = rewriter.create(op->getLoc(), rewriter.getI32Type(), dsizeI64); } } if (bcastShapeVec.size() == 0) { rewriter.replaceOpWithNewOp(op, outType, self); } else { Value bcastShapeTensor = rewriter.create( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); } return success(); } // AtenPermuteOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPermuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); // Not a ranked tensor type auto inType = self.getType().dyn_cast(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); if (!inType) return op.emitError("only ranked tensor types with static shapes are " "currently supported"); SmallVector permValues; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(permValues))) return rewriter.notifyMatchFailure( op, "only constant dimensions are currently supported"); int64_t inRank = inType.getRank(); for (auto &d : permValues) { d = toPositiveDim(d, inRank); if (!isValidDim(d, inRank)) return op.emitError("not all dims are valid"); } DenseIntElementsAttr permutation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); rewriter.replaceOpWithNewOp(op, outType, self, permutation); return success(); } // ValueTensorLiteralOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. // TODO: what about unsigned integer? if (auto elements = op.getValueAttr().dyn_cast()) { Type builtinTensorElemTy = resultType.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); DenseElementsAttr valueAttr = elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); }); rewriter.replaceOpWithNewOp(op, resultType, valueAttr); return success(); } rewriter.replaceOpWithNewOp(op, resultType, adaptor.getValue()); return success(); } // AtenTensorIntOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTensorIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type outElementType = resultType.getElementType(); Value innerValue = adaptor.getT(); Value stablehloTensor = hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType); rewriter.replaceOp(op, stablehloTensor); return success(); } // AtenReciprocalOp // Reciprocal(x) = Div(1, x) template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); if (!inputTy.getElementType().isa()) { return op.emitError("only floating-point datatype legalization supported " "for AtenReciprocalOp"); } Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } // AtenPowTensorScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.getExponent(); TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } DenseIntElementsAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); rewriter.replaceOp(op, result); return success(); } // PrimNumToTensorScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType outputType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); auto outputElemType = outputType.getElementType(); Value stablehloTensor = hlo::scalarToStablehloTensor( rewriter, op, adaptor.getA(), outputElemType); rewriter.replaceOp(op, stablehloTensor); return success(); } // AtenContiguousOp // Ref: TosaToTosa.cpp for implementation details template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenContiguousOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. auto selfType = adaptor.getSelf().getType().dyn_cast(); if (!selfType) return op.emitError("only tensor types are currently supported"); // FIXME: memory_format is not handled. rewriter.replaceOp(op, adaptor.getSelf()); return success(); } // AtenReluOp // Relu(x) = Max(0, x) template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsTy = lhs.getType().cast(); auto lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isa()) { return op->emitError("only float tensor in relu op is supported"); } Value zeroTensor; zeroTensor = chlo::getConstantLike( rewriter, op->getLoc(), APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), false), lhs); rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } // Convert a Aten::GELU to HLO // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputTy = input.getType().template dyn_cast(); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); auto rsqrtTwo = rewriter.create(loc, two); auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); auto erfAdd = rewriter.create(loc, erf, one); auto halfMul = rewriter.create(loc, erfAdd, half); rewriter.replaceOpWithNewOp(op, input, halfMul); return success(); } // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); if (!inputType.getElementType().isa()) { return rewriter.notifyMatchFailure(op, "only float tensor is supported"); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), input); return success(); } // AtenBatchNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); auto inputTy = input.getType().cast(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); Value runningMean = adaptor.getRunningMean(); Value runningVar = adaptor.getRunningVar(); // momentum is ignored Value momentum = adaptor.getMomentum(); (void)momentum; // handle feature index, see torch's BatchNorm1d, BatchNorm2d, BatchNorm3d, // all of NC, NCL, NCHW, NCDHW's feature index is 1. int64_t feature_index = 1; if (!inputTy.getElementType().template isa()) { return op.emitError("only input tensor of float type is supported"); } auto inputElemTy = inputTy.getElementType().cast(); Value channelDim = rewriter.create(op->getLoc(), input, feature_index); if (options.dimSizeIndexBits == 32) { auto channelDimI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), channelDim); channelDim = rewriter.create( op->getLoc(), rewriter.getI32Type(), channelDimI64); } Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { weight = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { bias = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { runningVar = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { runningMean = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); auto runningMeanTy = runningMean.getType().cast(); auto runningVarTy = runningVar.getType().cast(); if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } if (!weightTy.getElementType().template isa() || !biasTy.getElementType().template isa() || !runningMeanTy.getElementType().template isa() || !runningVarTy.getElementType().template isa()) { return op.emitError("only float weight/bias/runningMean/runningVar tensor " "of float type is supported"); } double eps = 0.0; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported"); } bool training = false; if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { return rewriter.notifyMatchFailure(op, "non-bool training unsupported"); } // TODO: handle cudnnEnabled parameter. Here, we just ignore it! bool cudnnEnabled = false; if (!matchPattern(op.getCudnnEnabled(), m_TorchConstantBool(&cudnnEnabled))) { return rewriter.notifyMatchFailure(op, "non-bool cudnn_enabled unsupported"); } if (training) { Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; if (weightTy.getElementType().cast().getWidth() > inputTy.getElementType().cast().getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType); weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); auto batchNormTrainingResult = rewriter.create( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), batchNormTrainingResult.getResult(0), outputTy.cast()); } else { auto batchNormTrainingResult = rewriter.create( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = batchNormTrainingResult.getResult(0); } rewriter.replaceOp(op, output); return success(); } else { Type outputTy = getTypeConverter()->convertType(op.getType()); SmallVector castShape{inputTy.getShape().begin(), inputTy.getShape().end()}; castShape[1] = weightTy.getShape()[0]; auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); // Feature counts must match among operands of // stablehlo::BatchNormInferenceOp. Value inputCasted = rewriter.create(op.getLoc(), castTy, input); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; if (weightTy.getElementType().cast().getWidth() > inputTy.getElementType().cast().getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType); weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); runningMean = hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType); runningVar = hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType); Value bnResult = rewriter.create( op.getLoc(), convertedType, input, weight, bias, runningMean, runningVar, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), bnResult, outputTy.cast()); } else { output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); } rewriter.replaceOpWithNewOp(op, outputTy, output); return success(); } } // AtenNativeLayerNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNativeLayerNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); if (!inputTy.hasStaticShape()) { return op->emitError("dynamic shaped input is not supported"); } SmallVector normalizedShape; if (!matchPattern(op.getNormalizedShape(), m_TorchListOfConstantInts(normalizedShape))) { return rewriter.notifyMatchFailure( op, "normalized_shape must be a list of const int"); } double eps = 0; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non const float eps is unsupported"); } if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias))) { return op->emitError("none weight or bias is unsupported"); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); if (!inputTy.getElementType().isa() || !biasTy.getElementType().isa() || !weightTy.getElementType().isa()) { return op->emitError("currently only float data type are supported"); } int64_t normalizedShapeRank = normalizedShape.size(); if (weightTy.getRank() != normalizedShapeRank || biasTy.getRank() != normalizedShapeRank || inputRank < normalizedShapeRank || normalizedShapeRank < 1) { return rewriter.notifyMatchFailure(op, "input or weight or bias shape or" "normalized shape not compatible"); } for (int64_t i = 1; i <= normalizedShapeRank; i++) { if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] || weightTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i] || biasTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i]) { return op.emitError("mismatching contracting dimension"); } } // Flatten dims to fit batch_norm operation. int64_t numFeatureDimSize = 1; int64_t numEmbeddingDimSize = 1; for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { numFeatureDimSize *= inputShape[i]; } for (int64_t i = 0; i < normalizedShapeRank; i++) { numEmbeddingDimSize *= normalizedShape[i]; } SmallVector inputFlattenShape{1, numFeatureDimSize, numEmbeddingDimSize}; SmallVector meanOrVarStablehloOutShape{numFeatureDimSize}; auto stablehloBatchNormOutTy = RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get( meanOrVarStablehloOutShape, inputTy.getElementType()); // Reshape input auto stablehloInput = rewriter.create( op->getLoc(), stablehloBatchNormOutTy, input, hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .value()); // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero(inputTy.getElementType() .cast() .getFloatSemantics())); SmallVector oneConstVec( numFeatureDimSize, APFloat( inputTy.getElementType().cast().getFloatSemantics(), 1)); auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); Value scale = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); Value offset = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); auto batchNormTrainingResult = rewriter.create( op->getLoc(), stablehloBatchNormOutTy, stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = getTypeConverter()->convertType(op.getType(0)).cast(); auto outputMeanOrVarTy = getTypeConverter()->convertType(op.getType(1)).cast(); auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), hlo::getConstTensor(rewriter, op, outputTy.getShape(), {static_cast(outputTy.getShape().size())}) .value()); auto mean = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); auto var = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); // Apply affine transform: output x weight + bias [element-wise] auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); auto outputMulWeight = rewriter.create(op->getLoc(), output, bcastedWeight); auto finalOuput = rewriter.create( op->getLoc(), outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } // AtenCatOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "only constant dim param is supported"); } dim = toPositiveDim(dim, outType.getRank()); if (!isValidDim(dim, outType.getRank())) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector torchTensors; if (!getListConstructElements(op.getTensors(), torchTensors)) { return rewriter.notifyMatchFailure( op, "input should comes from a PrimListConstructOp"); } SmallVector builtinTensors = getTypeConvertedValues( rewriter, op->getLoc(), getTypeConverter(), torchTensors); // Promote type for (auto &v : builtinTensors) { v = hlo::promoteType(rewriter, op->getLoc(), v, outType); } rewriter.replaceOpWithNewOp( op, outType, ValueRange(builtinTensors), dim); return success(); } // AtenNumelOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNumelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().dyn_cast(); size_t rank = selfTy.getRank(); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto loc = op->getLoc(); Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); for (size_t d = 0; d < rank; ++d) { Value dimSize = rewriter.create( loc, intType, rewriter.create(loc, self, d)); numel = rewriter.create(loc, numel, dimSize); } auto outTy = getTypeConverter()->convertType(op.getType()); if (outTy != numel.getType()) { rewriter.replaceOpWithNewOp(op, outTy, numel); } else { rewriter.replaceOp(op, numel); } return success(); } // AtenClampOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); auto inputElemType = inputType.getElementType(); Value minValue = adaptor.getMin(); Value maxValue = adaptor.getMax(); if (failed(checkNotNone(rewriter, op, minValue)) && failed(checkNotNone(rewriter, op, maxValue))) { return rewriter.notifyMatchFailure( op, "this op should be folded as its `min` and `max` both are none"); } else if (failed(checkNotNone(rewriter, op, minValue))) { maxValue = hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); if (failed(minInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate min value of dtype"); } minValue = *minInfo; } else if (failed(checkNotNone(rewriter, op, maxValue))) { minValue = hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); if (failed(maxInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate max value of dtype"); } maxValue = *maxInfo; } else { minValue = hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); maxValue = hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); } rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); return success(); } // AtenArangeStartStepOp // aten.arange.start_step = range(ceil((end-start)/step)) * step + start. template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); // Get element type of resultType as dtype auto outType = this->getTypeConverter() ->convertType(op.getType()) .cast(); auto dtype = outType.getElementType(); if (!dtype.isa() && !dtype.isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: only int or float dtype supported"); } Value start = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype); Value end = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype); Value step = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype); // Get length of the 1-d output tensor Value subOut = rewriter.create(loc, end, start); Value divOut = rewriter.create(loc, subOut, step); Value resultLength = rewriter.create( loc, RankedTensorType::get({1}, dtype), divOut); if (dtype.isa()) { resultLength = rewriter.create(loc, resultLength); resultLength = rewriter.create( loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); } Value window = rewriter.create(loc, outType, resultLength, 0); DenseIntElementsAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); rewriter.replaceOpWithNewOp(op, mulOut, start, broadcastDimensions); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluBackwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto outType = this->getTypeConverter()->convertType(op.getType()).cast(); if (!outType) { return op.emitError("only tensor type is supported"); } // TODO: Handle approximate. std::string approximate; if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) || approximate != "none") { return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } // Create constant value Value kAlpha = chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input); Value cstAlpha0 = chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input); Value half = chlo::getConstantLike(rewriter, loc, .5, input); Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); // Compute Value kBeta0 = rewriter.create(loc, outType, kAlpha, cstAlpha0); Value kBeta = rewriter.create(loc, outType, kBeta0, half); Value erfArg = rewriter.create(loc, outType, kAlpha, adaptor.getSelf()); Value erf = rewriter.create(loc, outType, erfArg); Value erfAdd = rewriter.create(loc, outType, erf, one); Value cdf = rewriter.create(loc, outType, erfAdd, half); Value inputSquared = rewriter.create( loc, outType, adaptor.getSelf(), adaptor.getSelf()); Value negHalfInputSquared = rewriter.create(loc, outType, inputSquared, negHalf); Value expRes = rewriter.create(loc, outType, negHalfInputSquared); Value pdf = rewriter.create(loc, outType, kBeta, expRes); Value pdfTimesInput = rewriter.create(loc, outType, pdf, adaptor.getSelf()); Value pdfTimesInputAddCdf = rewriter.create(loc, outType, pdfTimesInput, cdf); rewriter.replaceOpWithNewOp( op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.getExponent(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); auto outTy = this->getTypeConverter()->convertType(op.getType()).cast(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenUniformOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value generator = adaptor.getGenerator(); Location loc = op.getLoc(); if (!generator.getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); auto elements = self.getType().cast().getShape(); if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); auto shape_tensor = rewriter.create( loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = outTy.cast().getElementType(); Value from = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy); Value to = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getTo(), outElemTy); rewriter.replaceOpWithNewOp( op, outTy, from, to, shape_tensor, stablehlo::RngDistribution::UNIFORM); return success(); } // Converts `aten.empty.memory_format` to `tensor.empty` op. template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmptyMemoryFormatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // TODO: Add support pin_memory and memory_format features. // At this point all tensors should have value semantics, and hence the // `layout` check can be ignored. // The pin_memory should be either `False` or `none`. bool pinMemory; if (!op.getPinMemory().getType().template isa() && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) return rewriter.notifyMatchFailure( op, "unimplemented: pin_memory must be either None or false"); // Only `none`, `contiguous` and `preserve` memory_format is supported. if (!op.getMemoryFormat().getType().isa()) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( op, "unimplemented: the memory format should be specified in " "an integer constant"); if (memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::Preserve) return rewriter.notifyMatchFailure( op, "unimplemented: only none, contiguous and preserve " "memory_format is supported"); } if (!op.getDevice().getType().isa()) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( op, "unimplemented: device must be a constant str"); } // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. if (!op.getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( op, "unimplemented: layout must be a constant"); else if (tensorLayout != torch_upstream::Layout::Strided) return rewriter.notifyMatchFailure( op, "unimplemented: layout is expected to be strided"); } Location loc = op.getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); SmallVector resultSizeTorchInt, resultSize, resultSizeIndex; if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) { return rewriter.notifyMatchFailure( op, "unimplemented: size must be constructed using ListConstruct"); } resultSize = getTypeConvertedValues(rewriter, loc, typeConverter, resultSizeTorchInt); for (auto size : resultSize) resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); auto resultType = typeConverter->convertType(op.getType()).cast(); Type resultElementType; if (op.getDtype().getType().isa()) { resultElementType = resultType.getElementType(); } else { int64_t dtypeInt; if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); FailureOr maybeResultElementType = getTypeForScalarType( op->getContext(), (torch_upstream::ScalarType)dtypeInt, IntegerType::Signless); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); } resultElementType = *maybeResultElementType; } // Create an uninitialized tensor of `resultSize` shape. Value initTensor = rewriter.create( loc, getAsOpFoldResult(resultSizeIndex), resultElementType); rewriter.replaceOpWithNewOp(op, resultType, initTensor); return success(); } // RuntimeAssertOp namespace { class ConvertRuntimeAssertOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool condition; if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) { return rewriter.notifyMatchFailure( op, "unimplemented: condition must be a constant"); } if (!condition) { return op->emitError("condition must be true"); } rewriter.eraseOp(op); return success(); } }; } // namespace // AtenFillScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFillScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = getTypeConverter()->convertType(op.getType()).cast(); auto dtype = outType.getElementType(); Value scalarTensor = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); Value shapeTensor = rewriter.create(op->getLoc(), adaptor.getSelf()); Value bcastScalar = rewriter.create( op->getLoc(), outType, scalarTensor, shapeTensor, rewriter.getI64TensorAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); } // AtenFlipOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenFlipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = getTypeConverter()->convertType(op.getType()).cast(); SmallVector dims; if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) { return rewriter.notifyMatchFailure(op, "dims must be a list of const int"); } for (unsigned i = 0, e = dims.size(); i < e; i++) { dims[i] = toPositiveDim(dims[i], outType.getRank()); if (!isValidDim(dims[i], outType.getRank())) { return rewriter.notifyMatchFailure(op, "dim is statically invalid"); } } rewriter.replaceOpWithNewOp( op, outType, self, rewriter.getI64TensorAttr(dims)); return success(); } void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp); #undef INSERT_UNARY_PATTERN #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp); INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp); INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp); #undef INSERT_TENSOR_TO_SCALAR_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp); #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp); #undef INSERT_BINARY_MULDIV_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp); #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp); INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp); INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp); #undef INSERT_BINARY_LOGICAL_PATTERN #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenTensorIntOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenClampOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); INSERT_ATENOP_PATTERN(AtenNumelOp); INSERT_ATENOP_PATTERN(AtenSizeIntOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>( \ typeConverter, context) INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp); #undef INSERT_BINARY_BROADCAST_PATTERN }