//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; bool skipMultiplyAlpha(Value alphaValue) { double doubleValue; auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue)); int64_t intValue; auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue)); return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0)); } // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. namespace { template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only Tensor types supported in MHLO"); if (selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), self); return success(); } else { return op.emitError( "only floating-point datatype legalization supported"); } } }; } // namespace // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { template class ConvertAtenConstPatternOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template dyn_cast(); if (!outType) return op.emitError("only Tensor types supported in MHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) return op.emitError( "only floating-point or integer datatype legalization supported"); // FIXME: Handle layout, device and pin_memory. Assume dtype has been // processed to set output type correctly? if (!op.layout().getType().template isa()) return op.emitError("only default layout is supported"); bool pinMemory; if (!op.pin_memory().getType().template isa() && (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return op.emitError( "unsupported pin_memory, should be either None or false"); } SmallVector shape; if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) { return op.emitError("shape must be a list of Scalar constants"); } int64_t size = 1; for (auto s : shape) size *= s; SmallVector values(size, fillVal); auto constOp = mhlo::getConstTensor(rewriter, op, values, shape).getValue(); rewriter.replaceOpWithNewOp(op, outType, constOp); return success(); } }; } // namespace // These binary op legalizations are specific to add/sub which have an // alpha multiplier. namespace { template class ConvertAtenAddSubOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); RankedTensorType lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.other(); RankedTensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) return op.emitError("only Tensor types supported in MHLO"); TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsType) { rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy); } lhs = mhlo::promoteType(rewriter, lhs, outType); rhs = mhlo::promoteType(rewriter, rhs, outType); if (!skipMultiplyAlpha(op.alpha())) { Value alpha = mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy); DenseIntElementsAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); } DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); } }; } // namespace // Binary op legalizations for Mul/Div variants. namespace { template class ConvertAtenMulDivOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); auto lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.other(); TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) return op.emitError("only Tensor types supported in MHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (std::is_same()) { rhs = lhs; } else if (!rhsType) { rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy); } DenseIntElementsAttr bcastDimensions; lhs = mhlo::promoteType(rewriter, lhs, outType); rhs = mhlo::promoteType(rewriter, rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); if (!isa(op)) { rewriter.replaceOp(op, result); return success(); } AtenDivTensorModeOp divTensorModeOp = llvm::dyn_cast(op.getOperation()); std::string roundingMode; if (!matchPattern(divTensorModeOp.rounding_mode(), m_TorchConstantStr(roundingMode))) return rewriter.notifyMatchFailure( op, "only support constant str rounding mode"); if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. auto sign = rewriter.create(loc, result); auto abs = rewriter.create(loc, result); auto floor = rewriter.create(loc, abs); result = rewriter.create(loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) result = rewriter.create(loc, result).getResult(); } rewriter.replaceOp(op, result); return success(); } }; } // namespace // Binary op legalizations for comparator ops. namespace { template class ConvertAtenCompareOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); Value rhs = adaptor.other(); RankedTensorType lhsTy = lhs.getType().dyn_cast(); RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) return op.emitError("only Tensor types supported in MHLO"); RankedTensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); Type lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } if (!rhsTy) { rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy); } // TODO: what is the PyTorch default type promotion? rhs = mhlo::promoteType(rewriter, rhs, lhsTy); mhlo::ComparisonTypeAttr compareTypeAttr; mhlo::ComparisonDirectionAttr compareDirectionAttr; if (lhsElemTy.isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( op->getContext(), mhlo::ComparisonType::FLOAT); } else if (lhsElemTy.isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( op->getContext(), mhlo::ComparisonType::SIGNED); } if (std::is_same() || std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::LT); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::GT); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::EQ); } else if (std::is_same() || std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::NE); } DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); return success(); } }; } // namespace // AtenTransposeIntOp namespace { class ConvertAtenTransposeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); int64_t dim0; if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) { return rewriter.notifyMatchFailure(op, "dim0 must be constant"); } int64_t dim1; if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) { return rewriter.notifyMatchFailure(op, "dim1 must be constant"); } auto inType = self.getType().cast(); auto inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) { return rewriter.notifyMatchFailure(op, "dim0 out of range"); } dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) { return rewriter.notifyMatchFailure(op, "dim1 out of range"); } SmallVector permValues(inputRank); std::iota(std::begin(permValues), std::end(permValues), 0); std::swap(permValues[dim0], permValues[dim1]); DenseIntElementsAttr permutation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); rewriter.replaceOpWithNewOp(op, outType, self, permutation); return success(); } }; } // namespace // AtenBroadcastToOp namespace { class ConvertAtenBroadcastToOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); auto selfTy = self.getType().cast(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); #ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE if (selfTy.hasStaticShape()) { Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp); return success(); } #endif SmallVector shape; if (!(getListConstructElements(adaptor.size(), shape))) { return op->emitError("desired shape must be a list of scalar"); } SmallVector bcastShapeVec; int64_t totalRank = shape.size(); int64_t selfRank = selfTy.getRank(); int64_t leadingRank = totalRank - selfRank; for (int64_t i = 0; i < totalRank; ++i) { Value dValue = shape[i]; Value newD; int64_t dInt; if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) { return op->emitError("element of desired shape must be a scalar"); } if (i >= leadingRank && dInt == -1) { newD = rewriter.create(op->getLoc(), self, i - leadingRank); } else { dValue = rewriter.create(op->getLoc(), dValue); newD = rewriter.create( op->getLoc(), rewriter.getIndexType(), dValue); } bcastShapeVec.push_back(newD); } #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 for (auto &dsize : bcastShapeVec) { auto dsizeI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), dsize); dsize = rewriter.create(op->getLoc(), rewriter.getI32Type(), dsizeI64); } #endif Value bcastShapeTensor = rewriter.create( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); return success(); } }; } // namespace // AtenPermuteOp namespace { class ConvertAtenPermuteOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); // Not a ranked tensor type auto inType = self.getType().dyn_cast(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); if (!inType) return op.emitError("only ranked tensor types with static shapes are " "currently supported"); SmallVector permValues; if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) return rewriter.notifyMatchFailure( op, "only constant dimensions are currently supported"); int64_t inRank = inType.getRank(); for (auto &d : permValues) { d = toPositiveDim(d, inRank); if (!isValidDim(d, inRank)) return op.emitError("not all dims are valid"); } DenseIntElementsAttr permutation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); rewriter.replaceOpWithNewOp(op, outType, self, permutation); return success(); } }; } // namespace namespace { template class ConvertAtenOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace // AtenTanhOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTanhOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (selfTy && selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } else { return op.emitError( "only floating-point datatype legalization currently supported"); } } } // namespace // ValueTensorLiteralOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. // TODO: what about unsigned integer? if (auto elements = op.valueAttr().dyn_cast()) { Type builtinTensorElemTy = resultType.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); DenseElementsAttr valueAttr = elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); }); rewriter.replaceOpWithNewOp(op, resultType, valueAttr); return success(); } rewriter.replaceOpWithNewOp(op, resultType, adaptor.value()); return success(); } } // namespace // AtenReciprocalOp // Reciprocal(x) = Div(1, x) namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().cast(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); if (!inputTy.getElementType().isa()) { return op.emitError("only floating-point datatype legalization supported " "for AtenReciprocalOp"); } Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } } // namespace // PrimNumToTensorScalarOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { RankedTensorType outputType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); auto outputElemType = outputType.getElementType(); Value mhloTensor = mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType); rewriter.replaceOp(op, mhloTensor); return success(); } } // namespace // AtenContiguousOp // Ref: TosaToTosa.cpp for implementation details namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenContiguousOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) return op.emitError("only tensor types are currently supported"); // FIXME: memory_format is not handled. rewriter.replaceOp(op, adaptor.self()); return success(); } } // namespace // AtenReluOp // Relu(x) = Max(0, x) namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); auto lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isa()) { return op->emitError("only float tensor in relu op is supported"); } Value zeroTensor; zeroTensor = chlo::getConstantLike( rewriter, op->getLoc(), APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), false), lhs); rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } } // namespace // Convert a Aten::GELU to HLO // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.self(); auto inputTy = input.getType().template dyn_cast(); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); auto rsqrtTwo = rewriter.create(loc, two); auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); auto erfAdd = rewriter.create(loc, erf, one); auto halfMul = rewriter.create(loc, erfAdd, half); rewriter.replaceOpWithNewOp(op, input, halfMul); return success(); } } // namespace // AtenErfOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputType = input.getType().cast(); if (!inputType.getElementType().isa()) { return rewriter.notifyMatchFailure(op, "only float tensor is supported"); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), input); return success(); } } // namespace // AtenBatchNormOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.input(); // shape = [N, C, H, W] auto inputTy = input.getType().cast(); Value weight = adaptor.weight(); Value bias = adaptor.bias(); Value runningMean = adaptor.running_mean(); Value runningVar = adaptor.running_var(); // momentum is ignored Value momentum = adaptor.momentum(); (void)momentum; if (inputTy.getRank() <= 2) { return rewriter.notifyMatchFailure(op, "input should have rank larger than 2"); } if (!inputTy.getElementType().template isa()) { return op.emitError("only input tensor of float type is supported"); } auto inputElemTy = inputTy.getElementType().cast(); Value channelDim = rewriter.create(op->getLoc(), input, 1); #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 auto channelDimI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), channelDim); channelDim = rewriter.create( op->getLoc(), rewriter.getI32Type(), channelDimI64); #endif Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { weight = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { bias = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { runningVar = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { runningMean = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); auto runningMeanTy = runningMean.getType().cast(); auto runningVarTy = runningVar.getType().cast(); if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } if (!weightTy.getElementType().template isa() || !biasTy.getElementType().template isa() || !runningMeanTy.getElementType().template isa() || !runningVarTy.getElementType().template isa()) { return op.emitError("only float weight/bias/runningMean/runningVar tensor " "of float type is supported"); } double eps = 0.0; if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported"); } bool training = false; if (!matchPattern(op.training(), m_TorchConstantBool(&training))) { return rewriter.notifyMatchFailure(op, "non-bool training unsupported"); } // TODO: handle cudnnEnabled parameter. Here, we just ignore it! bool cudnnEnabled = false; if (!matchPattern(op.cudnn_enabled(), m_TorchConstantBool(&cudnnEnabled))) { return rewriter.notifyMatchFailure(op, "non-bool cudnn_enabled unsupported"); } if (training) { Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); auto batchNormTrainingResult = rewriter.create( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); return success(); } else { Type outputTy = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, outputTy, input, weight, bias, runningMean, runningVar, rewriter.getFloatAttr(inputTy.getElementType(), eps), rewriter.getI64IntegerAttr(1)); return success(); } } } // namespace // AtenNativeLayerNormOp namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNativeLayerNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.input(); auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); Value weight = adaptor.weight(); Value bias = adaptor.bias(); if (!inputTy.hasStaticShape()) { return op->emitError("dynamic shaped input is not supported"); } SmallVector normalizedShape; if (!matchPattern(op.normalized_shape(), m_TorchConstantIntList(normalizedShape))) { return rewriter.notifyMatchFailure( op, "normalized_shape must be a list of const int"); } double eps = 0; if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { return rewriter.notifyMatchFailure(op, "non const float eps is unsupported"); } if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias))) { return op->emitError("none weight or bias is unsupported"); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); if (!inputTy.getElementType().isa() || !biasTy.getElementType().isa() || !weightTy.getElementType().isa()) { return op->emitError("currently only float data type are supported"); } int64_t normalizedShapeRank = normalizedShape.size(); if (weightTy.getRank() != normalizedShapeRank || biasTy.getRank() != normalizedShapeRank || inputRank < normalizedShapeRank || normalizedShapeRank < 1) { return rewriter.notifyMatchFailure(op, "input or weight or bias shape or" "normalized shape not compatible"); } for (int64_t i = 1; i <= normalizedShapeRank; i++) { if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] || weightTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i] || biasTy.getShape()[normalizedShapeRank - i] != normalizedShape[normalizedShapeRank - i]) { return op.emitError("mismatching contracting dimension"); } } // Flatten dims to fit batch_norm operation. int64_t numFeatureDimSize = 1; int64_t numEmbeddingDimSize = 1; for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { numFeatureDimSize *= inputShape[i]; } for (int64_t i = 0; i < normalizedShapeRank; i++) { numEmbeddingDimSize *= normalizedShape[i]; } SmallVector inputFlattenShape{1, numFeatureDimSize, numEmbeddingDimSize}; SmallVector meanOrVarMhloOutShape{numFeatureDimSize}; auto mhloBatchNormOutTy = RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); auto mhloBathNormOutMeanOrVarTy = RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); // Reshape input auto mhloInput = rewriter.create( op->getLoc(), mhloBatchNormOutTy, input, mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .getValue()); // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero(inputTy.getElementType() .cast() .getFloatSemantics())); SmallVector oneConstVec( numFeatureDimSize, APFloat( inputTy.getElementType().cast().getFloatSemantics(), 1)); auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); Value scale = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); Value offset = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); auto batchNormTrainingResult = rewriter.create( op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = getTypeConverter()->convertType(op.getType(0)).cast(); auto outputMeanOrVarTy = getTypeConverter()->convertType(op.getType(1)).cast(); auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), mhlo::getConstTensor(rewriter, op, outputTy.getShape(), {static_cast(outputTy.getShape().size())}) .getValue()); auto mean = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), mhlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .getValue()); auto var = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), mhlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .getValue()); // Apply affine transform: output x weight + bias [element-wise] auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); auto outputMulWeight = rewriter.create(op->getLoc(), output, bcastedWeight); auto finalOuput = rewriter.create(op->getLoc(), outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } } // namespace void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context); INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context); INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); #undef INSERT_BINARY_MULDIV_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp); #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); #undef INSERT_ATENOP_PATTERN }