diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index d8d7d43c4..c9a2ad2e7 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -30,6 +30,18 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::torch_to_stablehlo;
 
+static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
+                                                 ArrayRef<int64_t> dims) {
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
+  for (size_t i = 0; i < inputShape.size(); i++) {
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputShape[i]);
+    }
+  }
+  return reduceResultShape;
+}
+
 static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                            PatternRewriter &rewriter) {
   auto constType = RankedTensorType::get({}, elementTy);
@@ -42,8 +54,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                      /*negative=*/false)});
       return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                     constAttr);
-    } else if (isa<mlir::IntegerType>(elementTy) &&
-               elementTy.getIntOrFloatBitWidth() != 8) {
+    } else if (isa<mlir::IntegerType>(elementTy)) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -59,8 +70,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                      /*negative=*/true)});
       return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                     constAttr);
-    } else if (isa<mlir::IntegerType>(elementTy) &&
-               elementTy.getIntOrFloatBitWidth() != 8) {
+    } else if (isa<mlir::IntegerType>(elementTy)) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
@@ -69,7 +79,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMinOp>(op)) {
+  if (isa<AtenMinOp, AtenMinDimOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -77,8 +87,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                      /*negative=*/false)});
       return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                     constAttr);
-    } else if (isa<mlir::IntegerType>(elementTy) &&
-               elementTy.getIntOrFloatBitWidth() != 8) {
+    } else if (isa<mlir::IntegerType>(elementTy)) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
@@ -93,8 +102,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
       auto constAttr = DenseElementsAttr::get(constType, one);
       return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                     constAttr);
-    } else if (isa<mlir::IntegerType>(elementTy) &&
-               elementTy.getIntOrFloatBitWidth() != 8) {
+    } else if (isa<mlir::IntegerType>(elementTy)) {
      APInt one(elementTy.getIntOrFloatBitWidth(), 1);
      auto constAttr = DenseElementsAttr::get(constType, one);
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -103,13 +111,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   }
 
   if (isa<AtenAllOp>(op)) {
-    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
+    auto constAttr =
+        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
   }
 
-  if (isa<AtenAnyOp>(op)) {
-    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
+  if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
+    auto constAttr =
+        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
   }
@@ -149,16 +159,17 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
     if (isa<AtenMaxOp, AtenMaxDimOp, AtenAmaxOp>(op)) {
       result = rewriter.create<stablehlo::MaxOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    } else if (isa<AtenMinOp, AtenAminOp>(op)) {
+    } else if (isa<AtenMinOp, AtenMinDimOp, AtenAminOp>(op)) {
       result = rewriter.create<stablehlo::MinOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    } else if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
+    } else if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
+                   AtenLinalgVectorNormOp>(op)) {
       result = rewriter.create<stablehlo::AddOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
     } else if (isa<AtenProdOp>(op)) {
       result = rewriter.create<stablehlo::MulOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    } else if (isa<AtenAnyOp>(op)) {
+    } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
       result = rewriter.create<stablehlo::OrOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
     } else if (isa<AtenAllOp>(op)) {
@@ -174,11 +185,11 @@
   return reduce.getResults()[0];
 }
 
-// Util for converting AtenArgmaxOp and AtenMaxDimOp
+// Util for converting AtenMaxDimOp/AtenMinDimOp
 static std::optional<ValueRange>
-getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
-            ArrayRef<Value> inputShapeVec, int64_t dim,
-            size_t dimSizeIndexBits) {
+createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value &input, ArrayRef<Value> inputShapeVec,
+                            int64_t dim, size_t dimSizeIndexBits) {
   auto inputTy = cast<RankedTensorType>(input.getType());
   if (!inputTy) {
     return std::nullopt;
@@ -199,8 +210,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
   }
 
-  std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
-  outputShape.erase(outputShape.begin() + dim);
+  auto outputShape = getReduceOutputShape(inputShape, {dim});
   auto outputTy = RankedTensorType::get(outputShape, inputElemTy);
   auto outputIndexTy =
       RankedTensorType::get(outputShape, rewriter.getIntegerType(64));
@@ -252,6 +262,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
       stablehlo::ComparisonDirectionAttr::get(
           rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareLeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::LE);
   stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
       stablehlo::ComparisonDirectionAttr::get(
           rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
@@ -260,11 +273,21 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&block);
 
-    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
-        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
-        compareGeDirectionAttr, compareTypeAttr);
+    Value compareResult;
+    if (isa<AtenMaxDimOp>(op)) {
+      compareResult = rewriter.create<stablehlo::CompareOp>(
+          op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+          compareGeDirectionAttr, compareTypeAttr);
+    } else if (isa<AtenMinDimOp>(op)) {
+      compareResult = rewriter.create<stablehlo::CompareOp>(
+          op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+          compareLeDirectionAttr, compareTypeAttr);
+    } else {
+      op->emitError("unimplemented lowering in createReduceOpReturnIndices");
+      return std::nullopt;
+    }
     Value retValResult = rewriter.create<stablehlo::SelectOp>(
-        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
+        op->getLoc(), compareResult, *firstValArg, *secondValArg);
 
     // get smaller index value if compared nums are equal.
     Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
         op->getLoc(), compareResultType, *firstIdxArg, *secondIdxArg,
         compareEqDirectionAttr, compareTypeAttr);
     Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
                                                      *secondIdxArg);
     Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
-        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
+        op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg);
     Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
         op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
 
     rewriter.create<stablehlo::ReturnOp>(
-        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
+        op->getLoc(), ValueRange{retValResult, retIdxResult});
   }
 
   return stablehloReduceOp.getResults();
 }
 
+static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value reduceResult,
+                                            ArrayRef<Value> inputShapeVec,
+                                            Type outType,
+                                            ArrayRef<int64_t> dims,
+                                            size_t dimSizeIndexBits) {
+  SmallVector<Value> outShapeVec(inputShapeVec);
+  Value one = rewriter.create<mlir::arith::ConstantOp>(
+      loc,
+      rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
+  for (auto dim : dims) {
+    outShapeVec[dim] = one;
+  }
+  auto outShapeTensor =
+      rewriter.create<mlir::tensor::FromElementsOp>(loc, outShapeVec);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(
+      loc, outType, reduceResult, outShapeTensor);
+}
+
 namespace {
 template <typename AtenOpT>
 class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
@@ -320,14 +362,6 @@ public:
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    // Currently, (u)int8 dtype is not supported
-    if (isa<mlir::IntegerType>(inputElemTy) &&
-        inputElemTy.getIntOrFloatBitWidth() == 8) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
-    }
-
     if (inputElemTy != outTy.getElementType()) {
       // use output type as computation type
       input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
@@ -347,7 +381,7 @@ public:
 };
 
 template <typename AtenOpT>
-class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
+class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp<AtenOpT> {
 public:
   using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
@@ -356,7 +390,10 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
     auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-    if (!inputTy) {
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
       return rewriter.notifyMatchFailure(
           op, "only Tensor types supported in StableHLO");
     }
@@ -366,12 +403,78 @@ public:
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    // Currently, (u)int8 dtype is not supported
-    if (isa<mlir::IntegerType>(inputElemTy) &&
-        inputElemTy.getIntOrFloatBitWidth() == 8) {
+    if (inputElemTy != outTy.getElementType()) {
+      // use output type as computation type
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
+    }
+
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
       return rewriter.notifyMatchFailure(
-          op,
-          "IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
+          op, "non-const integer `dim` is not supported");
+    }
+    dim = toPositiveDim(dim, inputTy.getRank());
+    SmallVector<int64_t> reduceResultShape =
+        getReduceOutputShape(inputTy.getShape(), {dim});
+
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input,
+        RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim},
+        rewriter);
+    if (!reduceResult) {
+      return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
+    }
+
+    if (keepDim) {
+      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
+                                                   options.dimSizeIndexBits);
+      if (failed(outShapeInfo)) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get dimension sizes of the input");
+      }
+      reduceResult = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
+          options.dimSizeIndexBits);
+    }
+    rewriter.replaceOp(op, reduceResult);
+    return success();
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    if (inputElemTy != outTy.getElementType()) {
+      // use output type as computation type
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
     }
 
     bool keepDim = false;
@@ -393,19 +496,16 @@ public:
       }
     }
     llvm::sort(dims.begin(), dims.end());
-    std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-    SmallVector<int64_t> reduceResultShape;
-    for (int64_t i = 0; i < inputTy.getRank(); i++) {
-      if (dimsSet.find(i) == dimsSet.end()) {
-        reduceResultShape.push_back(inputTy.getDimSize(i));
-      }
-    }
+    SmallVector<int64_t> reduceResultShape =
+        getReduceOutputShape(inputTy.getShape(), dims);
 
     Value reduceResult = createReduceOpWithSingleRegionOp(
-        op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+        op, input,
+        RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
         rewriter);
-    if (!reduceResult)
+    if (!reduceResult) {
       return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+    }
 
     if (keepDim) {
       const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
@@ -415,215 +515,104 @@ public:
        return rewriter.notifyMatchFailure(
            op, "failed to get dimension sizes of the input");
      }
-      auto outShapeVec = *outShapeInfo;
-      auto one = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      for (int64_t i : dims) {
-        outShapeVec[i] = one;
-      }
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
-      rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-          op,
-          ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          reduceResult, outShapeTensor);
-      return success();
+      reduceResult = reshapeReduceResultWhenKeepDim(
+          rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
+          options.dimSizeIndexBits);
     }
     rewriter.replaceOp(op, reduceResult);
     return success();
   }
 };
-} // namespace
 
-// AtenArgmaxOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
-    AtenArgmaxOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported!
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenArgmaxOp to StableHLO");
-  }
-
-  int64_t dim;
-  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
-    return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
-  }
-  dim = toPositiveDim(dim, inputTy.getRank());
-  if (!isValidDim(dim, inputTy.getRank())) {
-    return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-  }
-
-  bool keepDim = false;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
-    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
-  }
-
-  const auto &options = getOptions();
-  auto inputShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-  if (failed(inputShapeInfo)) {
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  }
-  auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
-
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
-        outShapeTensor);
-    return success();
-  }
-
-  rewriter.replaceOp(op, stablehloReduceResults[1]);
-  return success();
-}
-} // namespace
-
-// AtenMaxDimOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
-    AtenMaxDimOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  auto inputElemTy = inputTy.getElementType();
-  if (!inputElemTy.isIntOrFloat()) {
-    return op.emitError(
-        "Only floating-point or integer datatype legalization supported");
-  }
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxDimOp to StableHLO");
-  }
-
-  RankedTensorType valResultType = cast<RankedTensorType>(
-      getTypeConverter()->convertType(op.getResult(0).getType()));
-  RankedTensorType idxResultType = cast<RankedTensorType>(
-      getTypeConverter()->convertType(op.getResult(1).getType()));
-  Type idxElementType = idxResultType.getElementType();
-  if (!isa<mlir::IntegerType>(idxElementType)) {
-    return op.emitError("Aten.max.dim needs integer-like result");
-  }
-
-  int64_t dim;
-  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
-    return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
-  }
-  dim = toPositiveDim(dim, inputTy.getRank());
-  if (!isValidDim(dim, inputTy.getRank())) {
-    return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-  }
-  bool keepDim = false;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
-    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
-  }
-
-  const auto &options = getOptions();
-  auto inputShapeInfo =
-      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-  if (failed(inputShapeInfo)) {
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  }
-  auto inputShapeVec = *inputShapeInfo;
-
-  if (op.getResult(1).use_empty()) {
-    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
-    outputShape.erase(outputShape.begin() + dim);
-    Value reduceResult = createReduceOpWithSingleRegionOp(
-        op, input, RankedTensorType::get(outputShape, inputElemTy),
-        ArrayRef<int64_t>{dim}, rewriter);
-    if (!reduceResult)
-      return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
-
-    if (keepDim) {
-      auto outShapeVec = inputShapeVec;
-      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
-
-      auto stablehloReduceValueResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), valResultType, reduceResult, outShapeTensor);
-      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
-      return success();
+template <typename AtenOpT>
+class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
     }
-    rewriter.replaceOp(op, {reduceResult, Value()});
-    return success();
-  } else {
-    auto stablehloReduceResults =
-        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
-                    options.dimSizeIndexBits)
-            .value();
-    if (keepDim) {
-      auto outShapeVec = inputShapeVec;
-      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-          op->getLoc(),
-          rewriter.getIntegerAttr(
-              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-          op->getLoc(), outShapeVec);
+    RankedTensorType valResultType = cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getResult(0).getType()));
+    RankedTensorType idxResultType = cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getResult(1).getType()));
+    Type idxElementType = idxResultType.getElementType();
+    if (!isa<mlir::IntegerType>(idxElementType)) {
+      return op.emitError("indices result should be of integer type");
+    }
 
-      auto stablehloReduceValueResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), valResultType, stablehloReduceResults[0],
-              outShapeTensor);
-      auto stablehloReduceIndexResult =
-          rewriter.create<stablehlo::DynamicReshapeOp>(
-              op->getLoc(), idxResultType, stablehloReduceResults[1],
-              outShapeTensor);
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+      return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
+    }
+    dim = toPositiveDim(dim, inputTy.getRank());
+    if (!isValidDim(dim, inputTy.getRank())) {
+      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    }
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
+    const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+    auto inputShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
+                                                   options.dimSizeIndexBits);
+    if (failed(inputShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto inputShapeVec = *inputShapeInfo;
+
+    if (op.getResult(1).use_empty()) {
+      llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+      outputShape.erase(outputShape.begin() + dim);
+      Value reduceResult = createReduceOpWithSingleRegionOp(
+          op, input, RankedTensorType::get(outputShape, inputElemTy),
+          ArrayRef<int64_t>{dim}, rewriter);
+      if (!reduceResult) {
+        return op->emitError(
+            "createReduceOpWithSingleRegionOp returned nullptr");
+      }
+
+      if (keepDim) {
+        reduceResult = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
+            {dim}, options.dimSizeIndexBits);
+      }
+      rewriter.replaceOp(op, {reduceResult, Value()});
+      return success();
+    } else {
+      ValueRange stablehloReduceResults =
+          createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
+                                      options.dimSizeIndexBits)
+              .value();
+      // ValueRange::operator[] yields elements by value, so copy the results
+      // into a mutable vector before the keepDim reshape.
+      SmallVector<Value> reduceResults(stablehloReduceResults.begin(),
+                                       stablehloReduceResults.end());
+      if (keepDim) {
+        reduceResults[0] = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), reduceResults[0], inputShapeVec,
+            valResultType, {dim}, options.dimSizeIndexBits);
+        reduceResults[1] = reshapeReduceResultWhenKeepDim(
+            rewriter, op->getLoc(), reduceResults[1], inputShapeVec,
+            idxResultType, {dim}, options.dimSizeIndexBits);
+      }
       rewriter.replaceOp(
-          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+          op, {reduceResults[0], reduceResults[1]});
       return success();
     }
-    rewriter.replaceOp(op,
-                       {stablehloReduceResults[0], stablehloReduceResults[1]});
-    return success();
-  }
-}
+  }
+};
 } // namespace
 
 // AtenSumDimIntListOp
@@ -653,17 +642,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
        "Only floating-point or integer datatype legalization supported");
  }
 
-  // Currently, (u)int8 dtype is not supported
-  if (isa<mlir::IntegerType>(inputElemTy) &&
-      inputElemTy.getIntOrFloatBitWidth() == 8) {
-    return rewriter.notifyMatchFailure(
-        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenSumDimIntListOp to StableHLO");
-  }
-
   SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
-
   if (failed(checkNotNone(rewriter, op, op.getDim()))) {
     inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
   } else {
@@ -675,7 +655,6 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
     }
   }
-
   for (auto d : inputDims) {
     d = toPositiveDim(d, inputTy.getRank());
     // Drop invalid dims
@@ -683,46 +662,22 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       dims.push_back(d);
     }
   }
+  llvm::sort(dims.begin(), dims.end());
 
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputTy.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputTy.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
     return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
   }
 
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(),
-      RankedTensorType::get(reduceResultShape, outTy.getElementType()), input,
-      initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = stablehloReduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input,
+      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
   }
 
   if (keepDim) {
@@ -733,23 +688,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
      return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
    }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()),
-        stablehloReduceOp.getResult(0), outShapeTensor);
-    return success();
+    reduceResult = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
+        options.dimSizeIndexBits);
   }
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              stablehloReduceOp.getResults());
+  rewriter.replaceOp(op, reduceResult);
   return success();
 }
 } // namespace
@@ -789,18 +732,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
          "invalid dimension detected in `dim`");
    }
  }
-
   // Sort the dims in ascending order, making the conversion
   // stable with unordered dims.
   std::sort(dims.begin(), dims.end());
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputRank; i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputType.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputType.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
@@ -810,36 +747,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
 
   auto squareOp =
       rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
 
-  auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
-  if (!initValue) {
-    return failure();
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, squareOp.getResult(),
+      RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
   }
-
-  auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType),
-      squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = reduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto firstArgument = *block.args_begin();
-  auto secondArgument = *block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-
-    auto addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), firstArgument, secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
-  }
-
-  auto output =
-      rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
+  Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);
 
   if (keepDim) {
     auto outShapeInfo =
@@ -848,22 +763,12 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
      return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
    }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()), output,
-        outShapeTensor);
-    return success();
+    output = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), output, *outShapeInfo,
+        getTypeConverter()->convertType(op.getType()), dims,
+        options.dimSizeIndexBits);
   }
-  rewriter.replaceOp(op, output.getResult());
+  rewriter.replaceOp(op, output);
   return success();
 }
 } // namespace
@@ -920,13 +825,8 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
    std::sort(dims.begin(), dims.end());
  }
 
-  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
-  SmallVector<int64_t> reduceResultShape;
-  for (int64_t i = 0; i < inputType.getRank(); i++) {
-    if (dimsSet.find(i) == dimsSet.end()) {
-      reduceResultShape.push_back(inputType.getDimSize(i));
-    }
-  }
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputType.getShape(), dims);
 
   bool keepDim = false;
   if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
     return rewriter.notifyMatchFailure(
@@ -934,46 +834,27 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
        op, "non-const bool `keepdim` is not supported");
  }
 
-  auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter);
-  if (!initValue) {
-    return failure();
-  }
-
   Value absValue = rewriter.create<stablehlo::AbsOp>(op->getLoc(), input);
   Value powValue = rewriter.create<chlo::BroadcastPowOp>(op->getLoc(), absValue,
                                                          ord, nullptr);
 
-  auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType),
-      powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
-
-  Region &region = reduceOp.getBody();
-  Block &block = region.emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, outElemType);
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto firstArgument = *block.args_begin();
-  auto secondArgument = *block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-
-    auto addResult = rewriter.create<stablehlo::AddOp>(
-        op->getLoc(), firstArgument, secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
   }
+
+  auto scalarType = RankedTensorType::get({}, outElemType);
   auto constantOne = rewriter.create<stablehlo::ConstantOp>(
-      op->getLoc(), blockArgumentTy,
+      op->getLoc(), scalarType,
       DenseElementsAttr::get(
-          blockArgumentTy,
+          scalarType,
          APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
   auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
-      op->getLoc(), blockArgumentTy, constantOne, ord);
-  auto output = rewriter.create<chlo::BroadcastPowOp>(
-      op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr);
+      op->getLoc(), scalarType, constantOne, ord);
+  Value output = rewriter.create<chlo::BroadcastPowOp>(
+      op->getLoc(), reduceResult, reciprocalOrd, nullptr);
 
   if (keepDim) {
     auto outShapeInfo =
@@ -982,23 +863,11 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
      return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
    }
-    auto outShapeVec = *outShapeInfo;
-    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    for (int64_t i : dims) {
-      outShapeVec[i] = one;
-    }
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()), output,
-        outShapeTensor);
-    return success();
+    output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
+                                            *outShapeInfo, outType, dims,
+                                            options.dimSizeIndexBits);
   }
-
-  rewriter.replaceOp(op, output.getResult());
+  rewriter.replaceOp(op, output);
   return success();
 }
 } // namespace
@@ -1010,9 +879,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
 #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp)                               \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
-
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
@@ -1022,7 +888,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context,     \
                                                    options)
-
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
@@ -1031,12 +896,25 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
 #undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
 
-#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp)                      \
+#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp)                       \
   target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context,     \
-                                                   options)
+  patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context,      \
+                                                  options)
+  INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
+#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
 
-  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
-  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
-#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
+#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp)                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceDimsOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp);
+  INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp);
+#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN
+
+#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp)                     \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceWithIndicesOp<AtenOp>>(typeConverter, context, \
+                                                       options)
+  INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp);
+#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bc99fde51..6ac3ae099 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -32,6 +32,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     # unimplemented lowering torch -> linalg for torchvision.deform_conv2d
     # this is added to check the torch.onnx.export -> import_onnx -> torch path
     "DeformConv2D_basic",
+    "ReduceAnyDimFloatModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -340,6 +341,7 @@ TORCHDYNAMO_CRASHING_SET = {
 }
 
 FX_IMPORTER_XFAIL_SET = {
+    "ReduceAnyDimFloatModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
     "AnyBoolFalseModule_basic",
@@ -502,7 +504,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
     "ArgminModule_keepDim",
-    "ArgminModule_with_dim",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",
@@ -716,10 +717,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ReduceAllDimFloat_basic",
     "ReduceAllDimInt_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
-    "ReduceMinAlongDimNegative_basic",
-    "ReduceMinAlongDimSignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceMinAlongDim_basic",
     "ReduceMinKeepDimReturnBoth_basic",
     "ReduceMinKeepDim_basic",
     "ReduceProdDimIntFloatModule_basic",
@@ -832,6 +830,11 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "ReduceMinAlongDimNegative_basic",
+    "ReduceMinAlongDim_basic",
+    "ArgminModule_with_dim",
+    "ReduceMinAlongDimSignedInt_basic",
+    "ReduceAnyDimFloatModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
@@ -2198,6 +2201,7 @@ ONNX_XFAIL_SET = {
     # Failure - cast error
     "PermuteNegativeIndexModule_basic",
     # Failure - incorrect numerics
+    "ReduceAnyDimFloatModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BroadcastDynamicDimModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 4891d6eaa..347a1f8cc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -239,6 +239,26 @@ def ReduceAnyFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 
+class ReduceAnyDimFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.any(a, dim=0)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyDimFloatModule())
+def ReduceAnyDimFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
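
Reviewer note (illustration only, not part of the patch): the keepdim contract that the new reshapeReduceResultWhenKeepDim helper implements can be checked against PyTorch semantics directly. Reducing with keepdim=False and then reshaping each reduced dimension back to size 1 matches keepdim=True, which is exactly the stablehlo.dynamic_reshape the helper emits after createReduceOpWithSingleRegionOp:

    import torch

    a = torch.rand(3, 4, 5)
    dims = (0, 2)

    # keepdim=True keeps the reduced axes as size-1 dimensions.
    kept = torch.amax(a, dim=dims, keepdim=True)      # shape (1, 4, 1)

    # Two-step equivalent mirroring the lowering: reduce first (dropping
    # the dims), then reshape with 1s re-inserted at the reduced positions.
    reduced = torch.amax(a, dim=dims, keepdim=False)  # shape (4,)
    out_shape = [1 if i in dims else s for i, s in enumerate(a.shape)]
    assert torch.equal(kept, reduced.reshape(out_shape))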