diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index 502a837ea..73dbd9aef 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -71,7 +71,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMinOp, AtenMinDimOp>(op)) {
+  if (isa<AtenMinOp, AtenMinDimOp, AtenAminOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -151,6 +151,21 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
     if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
       result = rewriter.create<stablehlo::MaxOp>(
           op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAminOp, AtenMinOp>(op)) {
+      result = rewriter.create<stablehlo::MinOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenSumOp>(op)) {
+      result = rewriter.create<stablehlo::AddOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAllOp>(op)) {
+      result = rewriter.create<stablehlo::AndOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenAnyOp>(op)) {
+      result = rewriter.create<stablehlo::OrOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else if (isa<AtenProdOp>(op)) {
+      result = rewriter.create<stablehlo::MulOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
     } else {
       op->emitError("unimplemented lowering in "
                     "createReduceOpWithSingleRegionOp");
@@ -278,7 +293,150 @@ public:
   using OpAdaptor = typename AtenOpT::Adaptor;
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(false && "Unimplemented");
+    return failure();
+  };
+};
+
+template <typename AtenOpT>
+class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    auto outTy = dyn_cast<RankedTensorType>(
+        ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (!inputTy || !outTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    // Currently, (u)int8 dtype is not supported
+    if (isa<mlir::IntegerType>(inputElemTy) &&
+        inputElemTy.getIntOrFloatBitWidth() == 8) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "IntegerType with bitwidth 8 unsupported in conversion to StableHLO");
+    }
+
+    if (inputElemTy != outTy.getElementType()) {
+      // Use the output element type as the computation type.
+      input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
+                                                    outTy.getElementType());
+    }
+
+    SmallVector<int64_t> dims =
+        llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
+    Value result =
+        createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter);
+    if (!result) {
+      return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
+public:
+  using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(
+          op, "only Tensor types supported in StableHLO");
+    }
+
+    auto inputElemTy = inputTy.getElementType();
+    if (!inputElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "only floating-point or integer datatype legalization supported");
+    }
+    // Currently, (u)int8 dtype is not supported
+    if (isa<mlir::IntegerType>(inputElemTy) &&
+        inputElemTy.getIntOrFloatBitWidth() == 8) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "IntegerType with bitwidth 8 unsupported in conversion to StableHLO");
+    }
+
+    bool keepDim = false;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+    }
+
+    SmallVector<int64_t> inputDims;
+    SmallVector<int64_t> dims;
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const integer `dim` is not supported");
+    }
+    for (auto d : inputDims) {
+      d = toPositiveDim(d, inputTy.getRank());
+      // Drop invalid dims
+      if (isValidDim(d, inputTy.getRank())) {
+        dims.push_back(d);
+      }
+    }
+    llvm::sort(dims.begin(), dims.end());
+    std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+    SmallVector<int64_t> reduceResultShape;
+    for (int64_t i = 0; i < inputTy.getRank(); i++) {
+      if (dimsSet.find(i) == dimsSet.end()) {
+        reduceResultShape.push_back(inputTy.getDimSize(i));
+      }
+    }
+
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+        rewriter);
+    if (!reduceResult)
+      return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
+
+    if (keepDim) {
+      const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
+      auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
+                                                   options.dimSizeIndexBits);
+      if (failed(outShapeInfo)) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to get dimension sizes of the input");
+      }
+      auto outShapeVec = *outShapeInfo;
+      auto one = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      for (int64_t i : dims) {
+        outShapeVec[i] = one;
+      }
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+      rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+          op,
+          ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          reduceResult, outShapeTensor);
+      return success();
+    }
+    rewriter.replaceOp(op, reduceResult);
+    return success();
+  }
 };
 
 } // namespace
@@ -419,7 +577,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
       op, input, RankedTensorType::get(outputShape, inputElemTy),
      ArrayRef<int64_t>{dim}, rewriter);
   if (!reduceResult)
-    return failure();
+    return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
@@ -472,483 +630,6 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
 }
 } // namespace
 
-// AtenSumOp
-namespace {
-template <>
-LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
-    AtenSumOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
-  auto outTy = getTypeConverter()
-                   ->convertType(op.getType())
-                   .template dyn_cast<RankedTensorType>();
-  if (!inputTy) {
-    return rewriter.notifyMatchFailure(
-        op, "only Tensor types supported in StableHLO");
-  }
-  if (inputTy.getElementType() != outTy.getElementType()) {
-    // Use output element type as computation type.
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = dyn_cast(input.getType()); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); - } - - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenAllOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenAllOp to StableHLO"); - } - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - - if (inputElemTy != outTy.getElementType()) { - // Use output bool type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); - inputElemTy = inputTy.getElementType(); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value allResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), allResult); - } - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenAnyOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenAnyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenAllOp to StableHLO"); - } - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - - if (inputElemTy != outTy.getElementType()) { - // Use output bool type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); - inputElemTy = inputTy.getElementType(); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value anyResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), anyResult); - } - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenProdOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - if (inputTy.getElementType() != outTy.getElementType()) { - // Use output element type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = dyn_cast(input.getType()); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenProdOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value mulResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), mulResult); - } - - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); - - return success(); -} -} // namespace - -// AtenAmaxOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenAmaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxOp to StableHLO"); - } - - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); - } - - SmallVector inputDims; - SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { - return rewriter.notifyMatchFailure( - op, "non-const integer `dim` is not supported"); - } - for (auto d : inputDims) { - d = toPositiveDim(d, inputTy.getRank()); - // Drop invalid dims - if (isValidDim(d, inputTy.getRank())) { - dims.push_back(d); - } - } - llvm::sort(dims.begin(), dims.end()); - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputTy.getDimSize(i)); - } - } - - Value reduceResult = createReduceOpWithSingleRegionOp( - op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims, - rewriter); - if (!reduceResult) - 
return failure(); - - if (keepDim) { - const auto &options = getOptions(); - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(outShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), reduceResult, - outShapeTensor); - return success(); - } - rewriter.replaceOp(op, reduceResult); - return success(); -} -} // namespace - -// AtenMaxOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxOp to StableHLO"); - } - - SmallVector dims = - llvm::to_vector(llvm::seq(0, inputTy.getRank())); - - Value reduceResult = createReduceOpWithSingleRegionOp( - op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter); - if (!reduceResult) - return failure(); - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), reduceResult); - return success(); -} -} // namespace - -// AtenMinOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMinOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument 
= block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value minResult = rewriter.create<stablehlo::MinOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), minResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
-  return success();
-}
-} // namespace
-
 // AtenSumDimIntListOp
 namespace {
 template <>
@@ -1334,17 +1015,33 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
 #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp)                               \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
-  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
 #undef INSERT_ATEN_REDUCTION_OP_PATTERN
+
+#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp)                      \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context,     \
+                                                   options)
+
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp);
+  INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
+#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
+
+#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp)                      \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context,     \
+                                                   options)
+
+  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
+  INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
+#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index ceccb38be..ad7889057 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7257,6 +7257,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.amin\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
+"    %1 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@@ -12512,6 +12519,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %1 : !torch.tuple<int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.amin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.int {\n"
+"    %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
+"    return %0 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3fa3184d1..a2e490338 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -814,6 +814,7 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "ReduceAminSingleDim_basic",
     "AtenDotModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 81a860892..01a38c0fe 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -678,6 +678,9 @@ def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
 def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
 
+def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
+
 def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
 
@@ -4162,6 +4165,10 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int:
 def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
     return aten〇max〡dtype(self_rank_dtype), torch.int64
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int:
+    return aten〇min〡dtype(self_rank_dtype)
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
 def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
     return aten〇min〡dtype(self_rank_dtype), torch.int64
diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py
index 359316a2b..c525267c8 100644
--- a/projects/pt1/python/torch_mlir/torchscript.py
+++ b/projects/pt1/python/torch_mlir/torchscript.py
@@ -212,7 +212,12 @@ BACKEND_LEGAL_OPS = {
         "aten.adaptive_avg_pool2d",
         "aten.unflatten.int",
     ],
-    OutputType.STABLEHLO: ["aten.amax"],
+    OutputType.STABLEHLO: [
+        "aten.amax",
+        "aten.amin",
+        "aten.randn.generator",
+        "aten.normal_functional",
+    ],
 }
 
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 9e0869dd9..4891d6eaa 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -1230,6 +1230,29 @@ def ReduceAmaxKeepDim_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
+class ReduceAminSingleDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.amin(a, 1)
+
+
+@register_test_case(module_factory=lambda: ReduceAminSingleDim())
+def ReduceAminSingleDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, high=100))
+
+
+# ==============================================================================
+
+
 class ReduceMinFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
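
The new ReduceAminSingleDim test only exercises a plain reduction over one dimension. A companion test for the keepdim path, which is what drives the stablehlo.dynamic_reshape branch of the new ConvertAtenReduceKeepDimOp pattern, could look like the sketch below. The names ReduceAminKeepDim and ReduceAminKeepDim_basic are illustrative only and not part of this patch; it reuses @export, @annotate_args, register_test_case, and TestUtils exactly as the tests above do, and a matching "ReduceAminKeepDim_basic" entry would also have to be added to STABLEHLO_PASS_SET in xfail_sets.py for the StableHLO backend to run it.

class ReduceAminKeepDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        # Reduce over the last two dimensions and keep them as size-1 dims,
        # which should exercise the keepDim/dynamic_reshape path of the
        # ConvertAtenReduceKeepDimOp lowering added in this patch.
        return torch.ops.aten.amin(a, (1, 2), keepdim=True)


@register_test_case(module_factory=lambda: ReduceAminKeepDim())
def ReduceAminKeepDim_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, high=100))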