//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } // Max pooling if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } op->emitError("unimplemented lowering in AtenPoolingOp"); return nullptr; } namespace { template class ConvertAtenPoolingOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace // AtenMaxPool2dOp namespace { template <> LogicalResult ConvertAtenPoolingOp::matchAndRewrite( AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); if (inputRank <= 2) { return op.emitError( "max_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride, dilation; bool ceilMode = false; if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) { return rewriter.notifyMatchFailure(op, "non-const int dilation unsupported!"); } if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector mhloStride(inputRank, 1); SmallVector mhloDilation(inputRank, 1); SmallVector mhloKernelSize(inputRank, 1); SmallVector mhloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), mhloDilation.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), mhloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); mhloPadding[mhloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), mhloPadding); auto reduceWindowOp = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &block = reduceWindowOp.body().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputElemTy); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArg = block.args_begin(); auto secondArg = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value result = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), result); } rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } } // namespace // AtenMaxPool2dWithIndicesOp namespace { template <> LogicalResult ConvertAtenPoolingOp::matchAndRewrite( AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); auto outValTy = getTypeConverter()->convertType(op.getType(0)).cast(); auto outIdxTy = getTypeConverter()->convertType(op.getType(1)).cast(); if (inputRank <= 2) { return op.emitError( "max_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride, dilation; bool ceilMode = false; if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) { return rewriter.notifyMatchFailure(op, "non-const int dilation unsupported!"); } if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector mhloStride(inputRank, 1); SmallVector mhloDilation(inputRank, 1); SmallVector mhloKernelSize(inputRank, 1); SmallVector mhloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), mhloDilation.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), mhloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); mhloPadding[mhloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), mhloPadding); auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); SmallVector initIndexShapeVec; for (int64_t i = 0; i < inputRank - 2; i++) initIndexShapeVec.push_back(inputShapeVec[i]); initIndexShapeVec.push_back(rewriter.create( op->getLoc(), inputShapeVec[inputRank - 1], inputShapeVec[inputRank - 2])); auto initIndexShapeTensor = rewriter.create( op->getLoc(), initIndexShapeVec); SmallVector initIndexShapeForType(inputShape.begin(), inputShape.end() - 2); if (inputShape[inputRank - 1] == ShapedType::kDynamicSize || inputShape[inputRank - 2] == ShapedType::kDynamicSize) { initIndexShapeForType.push_back(ShapedType::kDynamicSize); } else { initIndexShapeForType.push_back(inputShape[inputRank - 1] * inputShape[inputRank - 2]); } auto initIndexTensor = rewriter .create( op->getLoc(), RankedTensorType::get(initIndexShapeForType, rewriter.getI64Type()), initIndexShapeTensor, static_cast(inputRank - 2)) .getResult(); auto indexTensor = rewriter .create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), initIndexTensor, inputShapeTensor) .getResult(); Value initIdx = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); auto reduceWindowOp = rewriter.create( op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &block = reduceWindowOp.body().emplaceBlock(); // Add bb argument auto blockValArgumentType = RankedTensorType::get({}, inputElemTy); auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); auto *firstValArg = block.args_begin(); auto *firstIdxArg = std::next(firstValArg); auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); mhlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( rewriter.getContext(), mhlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( rewriter.getContext(), mhlo::ComparisonType::SIGNED); } mhlo::ComparisonDirectionAttr compareGeDirectionAttr = mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), mhlo::ComparisonDirection::GE); mhlo::ComparisonDirectionAttr compareEqDirectionAttr = mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), mhlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } } // namespace // AtenAvgPool2dOp namespace { template <> LogicalResult ConvertAtenPoolingOp::matchAndRewrite( AtenAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); auto outShape = outTy.getShape(); if (inputRank <= 2) { return op.emitError( "avg_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; bool countIncludePad = true; if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } if (!(matchPattern(op.count_include_pad(), m_TorchConstantBool(&countIncludePad)))) { return rewriter.notifyMatchFailure( op, "non-const bool count_include_pad unsupported!"); } if (succeeded(checkNotNone(rewriter, op, op.divisor_override()))) { return rewriter.notifyMatchFailure( op, "only None divisor_override supported for now!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector mhloStride(inputRank, 1); SmallVector mhloDilation(inputRank, 1); SmallVector mhloKernelSize(inputRank, 1); SmallVector mhloPadding(inputRank * 2, 0); std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), mhloKernelSize.begin() + inputRank - 2); mhloPadding[mhloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1]; Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), mhloPadding); auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &sumBlock = reduceWindowSum.body().emplaceBlock(); // Add bb argument auto blockArgumentType = RankedTensorType::get({}, inputElemTy); sumBlock.addArgument(blockArgumentType, op->getLoc()); sumBlock.addArgument(blockArgumentType, op->getLoc()); auto *firstArg = sumBlock.args_begin(); auto secondArg = sumBlock.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { Value divisor = mhlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); divisor = mhlo::promoteType(rewriter, divisor, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); return success(); } // Use another mhlo.ReduceWindowOp to get the divisor Value windowSizeConst = mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); auto reduceWindowSize = rewriter.create( op->getLoc(), RankedTensorType::get(outShape, inputElemTy), windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &sizeBlock = reduceWindowSize.body().emplaceBlock(); // Add bb argument blockArgumentType = RankedTensorType::get({}, inputElemTy); sizeBlock.addArgument(blockArgumentType, op->getLoc()); sizeBlock.addArgument(blockArgumentType, op->getLoc()); firstArg = sizeBlock.args_begin(); secondArg = sizeBlock.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&sizeBlock); Value sumResult = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); } } // namespace void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); }