mirror of https://github.com/llvm/torch-mlir
[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo (#3348)
* Do not decompose `aten.amax` on the `stablehlo` backend, because it can be lowered to `stablehlo.reduce` directly.
* Lower `aten.max.dim` to a `stablehlo.reduce` that applies `max` when `AtenMaxDimOp.getIndices()` has no users, which is simpler.
parent 6b95dd461d
commit 5928f68e60
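In short: `stablehlo.reduce` takes an init value plus a scalar reduction region, so max-style reductions need no torch-level decomposition. Below is a minimal sketch of the resulting IR, with illustrative shapes and value names that are not taken from the commit (reducing dimension 1 of a 4x8 input):

  // aten.amax, and aten.max.dim when its indices result has no users,
  // now lower to a single value-only reduce applying maximum; the init
  // value is -inf for floating-point element types.
  %init = stablehlo.constant dense<0xFF800000> : tensor<f32>
  %0 = stablehlo.reduce(%arg0 init: %init) applies stablehlo.maximum
           across dimensions = [1]
           : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>

With `keepdim=true`, the reduced result is then expanded back to the input rank via `stablehlo.dynamic_reshape`, as the hunks below show.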
@@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
+  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
+                                              Type outTy,
+                                              ArrayRef<int64_t> dims,
+                                              PatternRewriter &rewriter) {
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy)
+    return nullptr;
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return nullptr;
+
+  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
+      op->getLoc(), outTy, input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = reduce.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result;
+    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
+      result = rewriter.create<stablehlo::MaxOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else {
+      op->emitError("unimplemented lowering in "
+                    "createReduceOpWithSingleRegionOp");
+      return nullptr;
+    }
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+  }
+  return reduce.getResults()[0];
+}
+
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static std::optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
@@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
 
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
+  if (op.getResult(1).use_empty()) {
+    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+    outputShape.erase(outputShape.begin() + dim);
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(outputShape, inputElemTy),
+        ArrayRef<int64_t>{dim}, rewriter);
+    if (!reduceResult)
+      return failure();
 
-    auto stablehloReduceValueResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), valResultType, stablehloReduceResults[0],
-            outShapeTensor);
-    auto stablehloReduceIndexResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), idxResultType, stablehloReduceResults[1],
-            outShapeTensor);
-    rewriter.replaceOp(
-        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, reduceResult, outShapeTensor);
+      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
+      return success();
+    }
+    rewriter.replaceOp(op, {reduceResult, Value()});
     return success();
+  } else {
+    auto stablehloReduceResults =
+        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                    options.dimSizeIndexBits)
+            .value();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, stablehloReduceResults[0],
+              outShapeTensor);
+      auto stablehloReduceIndexResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), idxResultType, stablehloReduceResults[1],
+              outShapeTensor);
+      rewriter.replaceOp(
+          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+      return success();
+    }
+    rewriter.replaceOp(op,
+                       {stablehloReduceResults[0], stablehloReduceResults[1]});
+    return success();
   }
-  rewriter.replaceOp(op,
-                     {stablehloReduceResults[0], stablehloReduceResults[1]});
-  return success();
 }
 } // namespace
@@ -692,6 +761,92 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
 }
 } // namespace
 
+// AtenAmaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
+    AtenAmaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenAmaxOp to StableHLO");
+  }
+
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  SmallVector<int64_t> inputDims;
+  SmallVector<int64_t> dims;
+  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    // Drop invalid dims
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputTy.getDimSize(i));
+    }
+  }
+
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+      rewriter);
+  if (!reduceResult)
+    return failure();
+
+  if (keepDim) {
+    const auto &options = getOptions();
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto outShapeVec = *outShapeInfo;
+    auto one = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    for (int64_t i : dims) {
+      outShapeVec[i] = one;
+    }
+    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), reduceResult,
+        outShapeTensor);
+    return success();
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
 // AtenMaxOp
 namespace {
 template <>
@@ -717,40 +872,16 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
             "AtenMaxOp to StableHLO");
   }
 
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
+  SmallVector<int64_t> dims =
+      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
 
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
+  if (!reduceResult)
     return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<stablehlo::MaxOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
-  }
 
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
+      op, getTypeConverter()->convertType(op.getType()), reduceResult);
   return success();
 }
 } // namespace
@@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
@@ -212,7 +212,7 @@ BACKEND_LEGAL_OPS = {
         "aten.adaptive_avg_pool2d",
         "aten.unflatten.int",
     ],
-    OutputType.STABLEHLO: [],
+    OutputType.STABLEHLO: ["aten.amax"],
 }