//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } // Max pooling if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } op->emitError("unimplemented lowering in AtenPoolingOp"); return nullptr; } // AtenMaxPool2dOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); if (inputRank <= 2) { return op.emitError( "max_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride, dilation; bool ceilMode = false; if (!(matchPattern(op.getKernelSize(), m_TorchListOfConstantInts(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { return rewriter.notifyMatchFailure(op, "non-const int dilation unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector stablehloStride(inputRank, 1); SmallVector stablehloDilation(inputRank, 1); SmallVector stablehloKernelSize(inputRank, 1); SmallVector stablehloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), stablehloDilation.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(), stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), stablehloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), stablehloPadding); auto reduceWindowOp = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &block = reduceWindowOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputElemTy); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArg = block.args_begin(); auto secondArg = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value result = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), result); } rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } // AtenMaxPool2dWithIndicesOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); auto outValTy = getTypeConverter()->convertType(op.getType(0)).cast(); auto outIdxTy = getTypeConverter()->convertType(op.getType(1)).cast(); if (inputRank <= 2) { return op.emitError( "max_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride, dilation; bool ceilMode = false; if (!(matchPattern(op.getKernelSize(), m_TorchListOfConstantInts(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { return rewriter.notifyMatchFailure(op, "non-const int dilation unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector stablehloStride(inputRank, 1); SmallVector stablehloDilation(inputRank, 1); SmallVector stablehloKernelSize(inputRank, 1); SmallVector stablehloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), stablehloDilation.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(), stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), stablehloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), stablehloPadding); const auto &options = getOptions(); auto inputShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); SmallVector initIndexShapeVec; for (int64_t i = 0; i < inputRank - 2; i++) initIndexShapeVec.push_back(inputShapeVec[i]); initIndexShapeVec.push_back(rewriter.create( op->getLoc(), inputShapeVec[inputRank - 1], inputShapeVec[inputRank - 2])); auto initIndexShapeTensor = rewriter.create( op->getLoc(), initIndexShapeVec); SmallVector initIndexShapeForType(inputShape.begin(), inputShape.end() - 2); if (inputShape[inputRank - 1] == ShapedType::kDynamic || inputShape[inputRank - 2] == ShapedType::kDynamic) { initIndexShapeForType.push_back(ShapedType::kDynamic); } else { initIndexShapeForType.push_back(inputShape[inputRank - 1] * inputShape[inputRank - 2]); } auto initIndexTensor = rewriter .create( op->getLoc(), RankedTensorType::get(initIndexShapeForType, rewriter.getI64Type()), initIndexShapeTensor, static_cast(inputRank - 2)) .getResult(); auto indexTensor = rewriter .create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), initIndexTensor, inputShapeTensor) .getResult(); Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); auto reduceWindowOp = rewriter.create( op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &block = reduceWindowOp.getBody().emplaceBlock(); // Add bb argument auto blockValArgumentType = RankedTensorType::get({}, inputElemTy); auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); auto *firstValArg = block.args_begin(); auto *firstIdxArg = std::next(firstValArg); auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::GE); stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } // AtenAvgPool2dOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); auto outShape = outTy.getShape(); if (inputRank <= 2) { return op.emitError( "avg_pooling2d only supports inputs with rank higher than 2"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; bool countIncludePad = true; if (!(matchPattern(op.getKernelSize(), m_TorchListOfConstantInts(kernelSize)))) { return rewriter.notifyMatchFailure( op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { return rewriter.notifyMatchFailure(op, "non-const int padding unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { return rewriter.notifyMatchFailure(op, "non-const bool ceil_mode unsupported!"); } if (!(matchPattern(op.getCountIncludePad(), m_TorchConstantBool(&countIncludePad)))) { return rewriter.notifyMatchFailure( op, "non-const bool count_include_pad unsupported!"); } if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) { return rewriter.notifyMatchFailure( op, "only None divisor_override supported for now!"); } // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input SmallVector stablehloStride(inputRank, 1); SmallVector stablehloDilation(inputRank, 1); SmallVector stablehloKernelSize(inputRank, 1); SmallVector stablehloPadding(inputRank * 2, 0); std::copy(stride.begin(), stride.end(), stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), stablehloKernelSize.begin() + inputRank - 2); stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), stablehloPadding); auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); // Add bb argument auto blockArgumentType = RankedTensorType::get({}, inputElemTy); sumBlock.addArgument(blockArgumentType, op->getLoc()); sumBlock.addArgument(blockArgumentType, op->getLoc()); auto *firstArg = sumBlock.args_begin(); auto secondArg = sumBlock.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { Value divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); divisor = hlo::promoteType(rewriter, divisor, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); return success(); } // Use another stablehlo.ReduceWindowOp to get the divisor Value windowSizeConst = hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); const auto &options = getOptions(); auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); auto reduceWindowSize = rewriter.create( op->getLoc(), RankedTensorType::get(outShape, inputElemTy), windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock(); // Add bb argument blockArgumentType = RankedTensorType::get({}, inputElemTy); sizeBlock.addArgument(blockArgumentType, op->getLoc()); sizeBlock.addArgument(blockArgumentType, op->getLoc()); firstArg = sizeBlock.args_begin(); secondArg = sizeBlock.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&sizeBlock); Value sumResult = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); } // AtenCumsumOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenCumsumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure( op, "unimplemented: dim must be a constant int"); } dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) { return rewriter.notifyMatchFailure(op, "dim is out of range"); } if (inputTy.isDynamicDim(dim)) { return rewriter.notifyMatchFailure( op, "unimplemented: cumsum dim must be static"); } Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); SmallVector stablehloKernelSize(inputRank, 1); stablehloKernelSize[dim] = inputShape[dim]; SmallVector stablehloStride(inputRank, 1); SmallVector stablehloDilation(inputRank, 1); SmallVector stablehloPadding(inputRank * 2, 0); stablehloPadding[dim * 2] = inputShape[dim] - 1; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), stablehloPadding); auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); // Add bb argument auto blockArgumentType = RankedTensorType::get({}, inputElemTy); sumBlock.addArgument(blockArgumentType, op->getLoc()); sumBlock.addArgument(blockArgumentType, op->getLoc()); auto *firstArg = sumBlock.args_begin(); auto *secondArg = std::next(firstArg); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = rewriter.create(op->getLoc(), *firstArg, *secondArg); rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOp(op, reduceWindowSum.getResults()); return success(); } void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add>(typeConverter, context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); }