//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( cast(elementTy).getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getInf(cast(elementTy).getFloatSemantics(), /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getInf(cast(elementTy).getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } if (isa(op)) { if (isa(elementTy)) { APFloat one(cast(elementTy).getFloatSemantics(), 1); auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { APInt one(elementTy.getIntOrFloatBitWidth(), 1); auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, constAttr); } } if (isa(op)) { auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)}); return rewriter.create(op->getLoc(), constType, constAttr); } if (isa(op)) { auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)}); return rewriter.create(op->getLoc(), constType, constAttr); } op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; } static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, Type outTy, ArrayRef dims, PatternRewriter &rewriter) { auto inputTy = dyn_cast(input.getType()); if (!inputTy) return nullptr; Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); if (!initValue) return nullptr; stablehlo::ReduceOp reduce = rewriter.create( op->getLoc(), outTy, input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = reduce.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArgument = block.args_begin(); auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value result; if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else { op->emitError("unimplemented lowering in " "createReduceOpWithSingleRegionOp"); return nullptr; } rewriter.create(op->getLoc(), result); } return reduce.getResults()[0]; } // Util for converting AtenArgmaxOp and AtenMaxDimOp static std::optional getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, ArrayRef inputShapeVec, int64_t dim, size_t dimSizeIndexBits) { auto inputTy = cast(input.getType()); if (!inputTy) { return std::nullopt; } if (!inputTy.getElementType().isIntOrFloat()) { return std::nullopt; } auto inputShape = inputTy.getShape(); auto inputElemTy = inputTy.getElementType(); Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) return std::nullopt; Value initIndex; if (dimSizeIndexBits == 32) { initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } else { initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } std::vector outputShape(inputShape.begin(), inputShape.end()); outputShape.erase(outputShape.begin() + dim); auto outputTy = RankedTensorType::get(outputShape, inputElemTy); auto outputIndexTy = RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(dimSizeIndexBits)), inputShapeTensor, static_cast(dim)); auto stablehloReduceOp = rewriter.create( op->getLoc(), TypeRange{outputTy, outputIndexTy}, ValueRange{input, indexTensor}, ValueRange{ initValue, initIndex, }, rewriter.getDenseI64ArrayAttr(dim)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); // Add block arguments auto blockValArgumentType = RankedTensorType::get({}, inputTy.getElementType()); auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits)); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); auto *firstValArg = block.args_begin(); auto *firstIdxArg = std::next(firstValArg); auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); } else if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::GE); stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } return stablehloReduceOp.getResults(); } namespace { template class ConvertAtenReductionOp : public ConvertAtenOp { public: using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(false && "Unimplemented"); return failure(); }; }; template class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp { public: using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); auto outTy = dyn_cast( ConvertAtenReductionOp::getTypeConverter()->convertType( op.getType())); if (!inputTy || !outTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion to StableHLO"); } if (inputElemTy != outTy.getElementType()) { // use output type as computation type input = rewriter.create(op->getLoc(), input, outTy.getElementType()); } SmallVector dims = llvm::to_vector(llvm::seq(0, inputTy.getRank())); Value result = createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter); if (!result) { return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } rewriter.replaceOp(op, result); return success(); } }; template class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { public: using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion to StableHLO"); } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } SmallVector inputDims; SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } for (auto d : inputDims) { d = toPositiveDim(d, inputTy.getRank()); // Drop invalid dims if (isValidDim(d, inputTy.getRank())) { dims.push_back(d); } } llvm::sort(dims.begin(), dims.end()); std::unordered_set dimsSet(dims.begin(), dims.end()); SmallVector reduceResultShape; for (int64_t i = 0; i < inputTy.getRank(); i++) { if (dimsSet.find(i) == dimsSet.end()) { reduceResultShape.push_back(inputTy.getDimSize(i)); } } Value reduceResult = createReduceOpWithSingleRegionOp( op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims, rewriter); if (!reduceResult) return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); if (keepDim) { const auto &options = ConvertAtenReductionOp::getOptions(); auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, ConvertAtenReductionOp::getTypeConverter()->convertType( op.getType()), reduceResult, outShapeTensor); return success(); } rewriter.replaceOp(op, reduceResult); return success(); } }; } // namespace // AtenArgmaxOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenArgmaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported! if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenArgmaxOp to StableHLO"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); } dim = toPositiveDim(dim, inputTy.getRank()); if (!isValidDim(dim, inputTy.getRank())) { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } const auto &options = getOptions(); auto inputShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, options.dimSizeIndexBits) .value(); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], outShapeTensor); return success(); } rewriter.replaceOp(op, stablehloReduceResults[1]); return success(); } } // namespace // AtenMaxDimOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMaxDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "Only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenMaxDimOp to StableHLO"); } RankedTensorType valResultType = cast( getTypeConverter()->convertType(op.getResult(0).getType())); RankedTensorType idxResultType = cast( getTypeConverter()->convertType(op.getResult(1).getType())); Type idxElementType = idxResultType.getElementType(); if (!isa(idxElementType)) { return op.emitError("Aten.max.dim needs integer-like result"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); } dim = toPositiveDim(dim, inputTy.getRank()); if (!isValidDim(dim, inputTy.getRank())) { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } const auto &options = getOptions(); auto inputShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; if (op.getResult(1).use_empty()) { llvm::SmallVector outputShape(inputTy.getShape()); outputShape.erase(outputShape.begin() + dim); Value reduceResult = createReduceOpWithSingleRegionOp( op, input, RankedTensorType::get(outputShape, inputElemTy), ArrayRef{dim}, rewriter); if (!reduceResult) return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); auto stablehloReduceValueResult = rewriter.create( op->getLoc(), valResultType, reduceResult, outShapeTensor); rewriter.replaceOp(op, {stablehloReduceValueResult, Value()}); return success(); } rewriter.replaceOp(op, {reduceResult, Value()}); return success(); } else { auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, options.dimSizeIndexBits) .value(); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); auto stablehloReduceValueResult = rewriter.create( op->getLoc(), valResultType, stablehloReduceResults[0], outShapeTensor); auto stablehloReduceIndexResult = rewriter.create( op->getLoc(), idxResultType, stablehloReduceResults[1], outShapeTensor); rewriter.replaceOp( op, {stablehloReduceValueResult, stablehloReduceIndexResult}); return success(); } rewriter.replaceOp(op, {stablehloReduceResults[0], stablehloReduceResults[1]}); return success(); } } } // namespace // AtenSumDimIntListOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenSumDimIntListOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); auto outTy = dyn_cast(getTypeConverter()->convertType(op.getType())); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } if (inputTy.getElementType() != outTy.getElementType()) { // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "Only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenSumDimIntListOp to StableHLO"); } SmallVector inputDims; SmallVector dims; if (failed(checkNotNone(rewriter, op, op.getDim()))) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); } else { if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } if (inputDims.size() == 0) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); } } for (auto d : inputDims) { d = toPositiveDim(d, inputTy.getRank()); // Drop invalid dims if (isValidDim(d, inputTy.getRank())) { dims.push_back(d); } } std::unordered_set dimsSet(dims.begin(), dims.end()); SmallVector reduceResultShape; for (int64_t i = 0; i < inputTy.getRank(); i++) { if (dimsSet.find(i) == dimsSet.end()) { reduceResultShape.push_back(inputTy.getDimSize(i)); } } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); if (!initValue) return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( op.getLoc(), RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArgument = block.args_begin(); auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); rewriter.create(op->getLoc(), addResult); } if (keepDim) { const auto &options = getOptions(); auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), stablehloReduceOp.getResult(0), outShapeTensor); return success(); } rewriter.replaceOpWithNewOp(op, outTy, stablehloReduceOp.getResults()); return success(); } } // namespace // AtenFrobeniusNormDimOp // aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given // dims) + stablehlo.sqrt namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenFrobeniusNormDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = dyn_cast(input.getType()); if (!inputType) { return op.emitError( "only ranked tensor input supported in AtenFrobeniusNormDimOp"); } auto inputRank = inputType.getRank(); auto inputElemType = inputType.getElementType(); if (!isa(inputElemType)) { return op.emitError( "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); } SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } for (auto &dim : dims) { dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) { return rewriter.notifyMatchFailure(op, "invalid dimension detected in `dim`"); } } // Sort the dims in ascending order, making the conversion // stable with unordered dims. std::sort(dims.begin(), dims.end()); std::unordered_set dimsSet(dims.begin(), dims.end()); SmallVector reduceResultShape; for (int64_t i = 0; i < inputRank; i++) { if (dimsSet.find(i) == dimsSet.end()) { reduceResultShape.push_back(inputType.getDimSize(i)); } } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( op, "non-const bool `keepdim` is not supported"); } auto squareOp = rewriter.create(op->getLoc(), input, input); auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); if (!initValue) { return failure(); } auto reduceOp = rewriter.create( op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputElemType); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto firstArgument = *block.args_begin(); auto secondArgument = *block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); auto addResult = rewriter.create( op->getLoc(), firstArgument, secondArgument); rewriter.create(op->getLoc(), addResult.getResult()); } auto output = rewriter.create(op->getLoc(), reduceOp.getResult(0)); if (keepDim) { auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output, outShapeTensor); return success(); } rewriter.replaceOp(op, output.getResult()); return success(); } } // namespace // AtenLinalgVectorNormOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenLinalgVectorNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = dyn_cast(input.getType()); if (!inputType) { return op.emitError( "only ranked tensor input supported in AtenLinalgVectorNormOp"); } int64_t inputRank = inputType.getRank(); auto outType = cast(getTypeConverter()->convertType(op.getType())); auto outElemType = outType.getElementType(); if (!isa(outElemType)) { return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); } if (inputType.getElementType() != outType.getElementType()) { input = rewriter.create(op->getLoc(), input, outElemType); } Value ord = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOrd(), outElemType); SmallVector dims; if (failed(checkNotNone(rewriter, op, op.getDim()))) { dims = llvm::to_vector<4>(llvm::seq(0, inputRank)); } else { if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } for (auto &dim : dims) { dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) { return rewriter.notifyMatchFailure( op, "invalid dimension detected in `dim`"); } } // Sort the dims in ascending order, making the conversion // stable with unordered dims. std::sort(dims.begin(), dims.end()); } std::unordered_set dimsSet(dims.begin(), dims.end()); SmallVector reduceResultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dimsSet.find(i) == dimsSet.end()) { reduceResultShape.push_back(inputType.getDimSize(i)); } } bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( op, "non-const bool `keepdim` is not supported"); } auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter); if (!initValue) { return failure(); } Value absValue = rewriter.create(op->getLoc(), input); Value powValue = rewriter.create(op->getLoc(), absValue, ord, nullptr); auto reduceOp = rewriter.create( op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, outElemType); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto firstArgument = *block.args_begin(); auto secondArgument = *block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); auto addResult = rewriter.create( op->getLoc(), firstArgument, secondArgument); rewriter.create(op->getLoc(), addResult.getResult()); } auto constantOne = rewriter.create( op->getLoc(), blockArgumentTy, DenseElementsAttr::get( blockArgumentTy, APFloat(cast(outElemType).getFloatSemantics(), 1))); auto reciprocalOrd = rewriter.create( op->getLoc(), blockArgumentTy, constantOne, ord); auto output = rewriter.create( op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr); if (keepDim) { auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output, outShapeTensor); return success(); } rewriter.replaceOp(op, output.getResult()); return success(); } } // namespace void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN #define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, \ options) INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp); #undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN #define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, \ options) INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp); INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp); #undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN }