[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo (#3348)

* not to decompose `aten.amax` on `stablehlo` backend. Because it could
be lowering to `stablehlo.reduce` directly.
* lowering `aten.max.dim` to `stablehlo.reduce apply max` when
`AtenMaxDimOp.getIndices()` doesn't have users. It's more simple.
pull/3349/head
Yuanqiang Liu 2024-05-16 00:05:19 +08:00 committed by GitHub
parent 6b95dd461d
commit 5928f68e60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 189 additions and 57 deletions

View File

@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
return nullptr;
}
static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
Type outTy,
ArrayRef<int64_t> dims,
PatternRewriter &rewriter) {
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy)
return nullptr;
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return nullptr;
stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), outTy, input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = reduce.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result;
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
result = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else {
op->emitError("unimplemented lowering in "
"createReduceOpWithSingleRegionOp");
return nullptr;
}
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
return reduce.getResults()[0];
}
// Util for converting AtenArgmaxOp and AtenMaxDimOp
static std::optional<ValueRange>
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
if (op.getResult(1).use_empty()) {
llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
outputShape.erase(outputShape.begin() + dim);
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(outputShape, inputElemTy),
ArrayRef<int64_t>{dim}, rewriter);
if (!reduceResult)
return failure();
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, stablehloReduceResults[0],
outShapeTensor);
auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, reduceResult, outShapeTensor);
rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
return success();
}
rewriter.replaceOp(op, {reduceResult, Value()});
return success();
} else {
auto stablehloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, stablehloReduceResults[0],
outShapeTensor);
auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
return success();
}
rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success();
}
rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success();
}
} // namespace
@ -692,6 +761,92 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
}
} // namespace
// AtenAmaxOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
AtenAmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to StableHLO");
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
// Drop invalid dims
if (isValidDim(d, inputTy.getRank())) {
dims.push_back(d);
}
}
llvm::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
rewriter);
if (!reduceResult)
return failure();
if (keepDim) {
const auto &options = getOptions();
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), reduceResult,
outShapeTensor);
return success();
}
rewriter.replaceOp(op, reduceResult);
return success();
}
} // namespace
// AtenMaxOp
namespace {
template <>
@ -717,40 +872,16 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
"AtenMaxOp to StableHLO");
}
SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}
SmallVector<int64_t> dims =
llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
Value reduceResult = createReduceOpWithSingleRegionOp(
op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
if (!reduceResult)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value maxResult = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
op, getTypeConverter()->convertType(op.getType()), reduceResult);
return success();
}
} // namespace
@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);

View File

@ -212,7 +212,7 @@ BACKEND_LEGAL_OPS = {
"aten.adaptive_avg_pool2d",
"aten.unflatten.int",
],
OutputType.STABLEHLO: [],
OutputType.STABLEHLO: ["aten.amax"],
}