[MHLO] Support aten.cumsum op in mhlo backend (#1825)

Refs: pull/1834/head, snapshot-20230130.734
Jiahao Li, 2023-01-30 13:38:27 +08:00, committed by GitHub
parent 6040d7ce00
commit f5b689e12f
3 changed files with 103 additions and 1 deletion

File: e2e_testing/xfail_sets.py

@@ -107,6 +107,8 @@ MHLO_PASS_SET = {
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
+    "CumsumStaticModule_basic",
+    "CumsumStaticNegativeDimModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
     "ElementwiseAtenLogicalNotOpModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",

File: lib/Conversion/TorchToMhlo/Pooling.cpp

@@ -34,7 +34,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                 PatternRewriter &rewriter) {
   auto constType = RankedTensorType::get({}, elementTy);
   // Avg pooling
-  if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
+  if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
     if (elementTy.isa<mlir::FloatType>()) {
       auto constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
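Cumsum joins the avg-pooling branch of this helper because both ops reduce with addition, so the init value that fills padded window slots must be the additive identity:

$$\operatorname{cumsum}(x)_i = \sum_{j \le i} x_j, \qquad 0 + a = a,$$

which is why the zero padding introduced in the new pattern below never perturbs a prefix sum.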
@@ -531,6 +531,87 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
   return success();
 }

+// AtenCumsumOp
+template <>
+LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
+    AtenCumsumOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  auto inputElemTy = inputTy.getElementType();
+  auto inputRank = inputTy.getRank();
+  auto inputShape = inputTy.getShape();
+  auto outTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: dim must be a constant int");
+  }
+  dim = toPositiveDim(dim, inputRank);
+  if (!isValidDim(dim, inputRank)) {
+    return rewriter.notifyMatchFailure(op, "dim is out of range");
+  }
+  if (inputTy.isDynamicDim(dim)) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: cumsum dim must be static");
+  }
+
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
+  mhloKernelSize[dim] = inputShape[dim];
+  SmallVector<int64_t> mhloStride(inputRank, 1);
+  SmallVector<int64_t> mhloDilation(inputRank, 1);
+  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  mhloPadding[dim * 2] = inputShape[dim] - 1;
+
+  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+                            rewriter.getI64Type()),
+      mhloKernelSize);
+  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+                            rewriter.getI64Type()),
+      mhloStride);
+  DenseIntElementsAttr baseDilations;
+  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+                            rewriter.getI64Type()),
+      mhloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      mhloPadding);
+
+  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
+      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+      baseDilations, windowDilations, pad);
+
+  Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
+
+  // Add bb argument
+  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
+  sumBlock.addArgument(blockArgumentType, op->getLoc());
+  sumBlock.addArgument(blockArgumentType, op->getLoc());
+  auto *firstArg = sumBlock.args_begin();
+  auto *secondArg = std::next(firstArg);
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&sumBlock);
+
+    Value sumResult =
+        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+  }
+
+  rewriter.replaceOp(op, reduceWindowSum.getResults());
+  return success();
+}
+
 void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -542,4 +623,6 @@ void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
                                                           context, options);
+  target.addIllegalOp<AtenCumsumOp>();
+  patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
 }
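Taken together, the pattern computes cumsum as a full-width sliding-window sum: the window spans the whole extent of dim (mhloKernelSize[dim] = inputShape[dim]) and inputShape[dim] - 1 zeros are padded on the low side (mhloPadding[dim * 2]), so the window that ends at position i covers exactly the prefix x[0..i]. A minimal NumPy sketch of that equivalence (the helper name and the explicit loop are illustrative only; in the lowering this is a single mhlo.reduce_window):

    import numpy as np

    def cumsum_via_window_sum(x, dim):
        # Window size equals the (static) extent of `dim`; pad size-1 zeros
        # on the low side, mirroring mhloPadding[dim * 2] = inputShape[dim] - 1.
        size = x.shape[dim]
        x = np.moveaxis(x, dim, -1)
        padded = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(size - 1, 0)])
        # The window ending at padded index i + size covers the prefix x[:i + 1].
        out = np.stack([padded[..., i:i + size].sum(axis=-1) for i in range(size)],
                       axis=-1)
        return np.moveaxis(out, -1, dim)

    x = np.arange(1.0, 5.0)                      # [1, 2, 3, 4]
    print(cumsum_via_window_sum(x, 0))           # [1, 3, 6, 10]
    assert np.array_equal(cumsum_via_window_sum(x, 0), np.cumsum(x))

This also explains the two guards in the pattern: the window dimensions of mhlo.reduce_window are compile-time attributes, so the extent of dim must be static, and stride and dilation stay at 1 so that every prefix position is produced.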

File: python/torch_mlir_e2e_test/test_suite/basic.py

@@ -3035,6 +3035,23 @@ class CumsumStaticModule(torch.nn.Module):
 def CumsumStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 7, 4))

+
+class CumsumStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 7, 4], torch.float32, True),
+    ])
+    def forward(self, val):
+        return torch.ops.aten.cumsum(val, dim=-1)
+
+
+@register_test_case(module_factory=lambda: CumsumStaticNegativeDimModule())
+def CumsumStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))

 # ==============================================================================

 class AtenToDeviceModule(torch.nn.Module):
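The new test exercises dim=-1, which the conversion normalizes via toPositiveDim before indexing the window attributes. For completeness, a hedged sketch of driving the new path end to end outside the test harness, assuming a torch-mlir build from around this snapshot where torch_mlir.compile accepts output_type="mhlo" (the API surface has changed in later versions):

    # Sketch only: requires a torch-mlir build with the MHLO output type.
    import torch
    import torch_mlir

    class Cumsum(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.cumsum(x, dim=-1)

    # Static shapes are required: the lowering bakes the dim extent into
    # the reduce_window window-dimensions attribute.
    compiled = torch_mlir.compile(Cumsum(), torch.rand(2, 7, 4),
                                  output_type="mhlo")
    print(compiled)  # expect an mhlo.reduce_window over the last dimension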