[Stablehlo] refactor reduction lowering and support aten.amin (#3383)

* implement dedicated lowering template patterns
`ConvertAtenReduceAllDimsOp` and `ConvertAtenReduceKeepDimOp`
* support lowering of `aten.amin`.
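
For context, `aten.amin` computes the minimum over the listed dimensions and mirrors the already-supported `aten.amax`. A minimal eager-PyTorch sketch of the semantics the new lowering has to reproduce (illustrative only, not part of this commit):

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
print(torch.amin(x, dim=1).shape)                # torch.Size([2, 4])
print(torch.amin(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])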
Yuanqiang Liu 2024-05-23 20:40:20 +08:00 committed by GitHub
parent 43f961eca4
commit 5bb1a65ec9
6 changed files with 232 additions and 488 deletions

@@ -71,7 +71,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}
if (isa<AtenMinOp>(op)) {
if (isa<AtenAminOp, AtenMinOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@@ -151,6 +151,21 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
result = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAminOp, AtenMinOp>(op)) {
result = rewriter.create<stablehlo::MinOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenSumOp>(op)) {
result = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAllOp>(op)) {
result = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAnyOp>(op)) {
result = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenProdOp>(op)) {
result = rewriter.create<stablehlo::MulOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else {
op->emitError("unimplemented lowering in "
"createReduceOpWithSingleRegionOp");
@@ -278,7 +293,150 @@ public:
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
ConversionPatternRewriter &rewriter) const override {
assert(false && "Unimplemented");
return failure();
};
};
template <typename AtenOpT>
class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp<AtenOpT> {
public:
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = dyn_cast<RankedTensorType>(
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (!inputTy || !outTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op,
"IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
}
if (inputElemTy != outTy.getElementType()) {
// use output type as computation type
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
outTy.getElementType());
}
SmallVector<int64_t> dims =
llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
Value result =
createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter);
if (!result) {
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
}
rewriter.replaceOp(op, result);
return success();
}
};
template <typename AtenOpT>
class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp<AtenOpT> {
public:
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op,
"IntegerType with bitwidth 8 unsupported in convertion to StableHLO");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
// Drop invalid dims
if (isValidDim(d, inputTy.getRank())) {
dims.push_back(d);
}
}
llvm::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
rewriter);
if (!reduceResult)
return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
if (keepDim) {
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op,
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
reduceResult, outShapeTensor);
return success();
}
rewriter.replaceOp(op, reduceResult);
return success();
}
};
} // namespace
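
Taken together, the two templates split the reduction patterns by shape behavior: `ConvertAtenReduceAllDimsOp` reduces every dimension down to a scalar, while `ConvertAtenReduceKeepDimOp` reduces a constant `dim` list and, when `keepdim` is true, reshapes so each reduced dimension survives with size 1. A sketch of that shape logic under a static-shape assumption (`reduce_result_shape` is a hypothetical helper; the C++ handles dynamic dims via `hlo::getDimSizesOfTensor` and `stablehlo.dynamic_reshape`):

def reduce_result_shape(in_shape, dims, keepdim):
    # Normalize negative dims, as toPositiveDim does in the pattern.
    dims = {d % len(in_shape) for d in dims}
    if keepdim:
        return [1 if i in dims else s for i, s in enumerate(in_shape)]
    return [s for i, s in enumerate(in_shape) if i not in dims]

print(reduce_result_shape([2, 3, 4], [-2], keepdim=False))  # [2, 4]
print(reduce_result_shape([2, 3, 4], [-2], keepdim=True))   # [2, 1, 4]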
@@ -419,7 +577,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
op, input, RankedTensorType::get(outputShape, inputElemTy),
ArrayRef<int64_t>{dim}, rewriter);
if (!reduceResult)
return failure();
return op->emitError("createReduceOpWithSingleRegionOp returned nullptr");
if (keepDim) {
auto outShapeVec = inputShapeVec;
@@ -472,483 +630,6 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
}
} // namespace
// AtenSumOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
AtenSumOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = dyn_cast<RankedTensorType>(input.getType());
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumOp to StableHLO");
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
initValue, rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenAllOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
AtenAllOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value allResult = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenAnyOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
AtenAnyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value anyResult = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenProdOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
AtenProdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = dyn_cast<RankedTensorType>(input.getType());
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenProdOp to StableHLO");
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
initValue, rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value mulResult = rewriter.create<stablehlo::MulOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), mulResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenAmaxOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
AtenAmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to StableHLO");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
// Drop invalid dims
if (isValidDim(d, inputTy.getRank())) {
dims.push_back(d);
}
}
llvm::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
rewriter);
if (!reduceResult)
return failure();
if (keepDim) {
const auto &options = getOptions();
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), reduceResult,
outShapeTensor);
return success();
}
rewriter.replaceOp(op, reduceResult);
return success();
}
} // namespace
// AtenMaxOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
AtenMaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to StableHLO");
}
SmallVector<int64_t> dims =
llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
if (!reduceResult)
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), reduceResult);
return success();
}
} // namespace
// AtenMinOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
AtenMinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMinOp to StableHLO");
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value minResult = rewriter.create<stablehlo::MinOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), minResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenSumDimIntListOp
namespace {
template <>
@@ -1334,17 +1015,33 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp);
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReduceKeepDimOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp);
INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp);
#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN
}

@@ -7257,6 +7257,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.amin\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %none : !torch.none to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@@ -12512,6 +12519,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.amin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.int {\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"

@@ -814,6 +814,7 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"ReduceAminSingleDim_basic",
"AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",

@@ -678,6 +678,9 @@ def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)

def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)

def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
@@ -4162,6 +4165,10 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int:
def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
    return aten〇max〡dtype(self_rank_dtype), torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int:
    return aten〇min〡dtype(self_rank_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
    return aten〇min〡dtype(self_rank_dtype), torch.int64
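
The new dtype function encodes that `aten.amin`, like `aten.min`, returns a single tensor with the input's dtype, whereas the `.dim` variants return a (values, indices) tuple whose indices are `torch.int64`. A quick eager-mode check of that contract (illustrative only):

import torch

x = torch.randn(3, 4)
assert torch.amin(x, dim=1).dtype == x.dtype  # one tensor, input dtype
values, indices = torch.min(x, dim=1)         # tuple: values plus indices
assert indices.dtype == torch.int64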

@@ -212,7 +212,12 @@ BACKEND_LEGAL_OPS = {
"aten.adaptive_avg_pool2d",
"aten.unflatten.int",
],
OutputType.STABLEHLO: ["aten.amax"],
OutputType.STABLEHLO: [
"aten.amax",
"aten.amin",
"aten.randn.generator",
"aten.normal_functional",
],
}

@@ -1230,6 +1230,29 @@ def ReduceAmaxKeepDim_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceAminSingleDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.amin(a, 1)


@register_test_case(module_factory=lambda: ReduceAminSingleDim())
def ReduceAminSingleDim_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, high=100))
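
Only the single-dim, keepdim=False case is exercised here. A keepdim variant in the same style would look like the following (an illustrative sketch, not part of this commit; `ReduceAminKeepDim` is a hypothetical name):

class ReduceAminKeepDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        # dim=[1], keepdim=True: the result keeps dim 1 with size 1.
        return torch.ops.aten.amin(a, [1], True)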
# ==============================================================================
class ReduceMinFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()