[Torch Dialect] Decompose AtenTriuOp (#2561)

Decompose `aten.triu` as in the following eager-mode equivalent:
```
import torch

def my_triu(x, diag):
    # Sizes of the last two (matrix) dimensions.
    rows = torch.ops.aten.size(x, -2)
    cols = torch.ops.aten.size(x, -1)

    # Column vector of row indices and row vector of column indices;
    # together they broadcast to a (rows, cols) grid.
    row_indices = torch.ops.aten.arange(rows).unsqueeze(1)
    col_indices = torch.ops.aten.arange(cols).unsqueeze(0)

    # Keep element (i, j) iff j >= i + diag; fill the rest with 0.
    cond = torch.ops.aten.ge(
        col_indices, torch.ops.aten.add(row_indices, diag))
    return torch.ops.aten.where(cond, x, 0)

x = torch.rand(5, 7)
assert torch.allclose(my_triu(x, 0), torch.triu(x, 0))
assert torch.allclose(my_triu(x, 1), torch.triu(x, 1))
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
assert torch.allclose(my_triu(x, -1), torch.triu(x, -1))
```
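The (rows, cols) mask broadcasts against any leading batch dimensions, so the same sketch also covers higher-rank inputs (what the `TriuBroadcastModule` test below exercises); a quick check reusing `my_triu` from above:
```
x = torch.rand(3, 4, 5, 6)
# cond has shape (5, 6) and broadcasts across the leading (3, 4) dims.
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
```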

---------

Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>

3 changed files with 104 additions and 0 deletions
@@ -246,6 +246,62 @@ public:
};
} // end namespace
namespace {
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTriuOp op,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = op.getContext();
    Location loc = op.getLoc();
    Value input = op.getSelf();
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasSizes() || !inputType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "should have shape and dtype");
    }
    if (inputType.getSizes().size() < 2) {
      return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
    }

    // Intermediate values carry no static shape/dtype information.
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
    Value cstZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value cstOne =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value none = rewriter.create<ConstantNoneOp>(loc);

    // Sizes of the last two (matrix) dimensions.
    Value rowDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-2));
    Value colDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
    Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);

    // Index vectors, unsqueezed to (rows, 1) and (1, cols) so that they
    // broadcast to a (rows, cols) grid.
    Value rowArange = rewriter.create<AtenArangeOp>(
        loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);
    Value colArange = rewriter.create<AtenArangeOp>(
        loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);
    Value unsqueezeRowArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
    Value unsqueezeColArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);

    // Keep element (i, j) iff j >= i + diagonal; fill the rest with zero.
    Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
        loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
    Value condTensor = rewriter.create<AtenGeTensorOp>(
        loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
        op, op.getResult().getType(), condTensor, input, cstZero);
    return success();
  }
};
} // namespace
namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
@@ -5817,6 +5873,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
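
In eager terms, the pattern keeps element (i, j) of the last two dimensions exactly when j >= i + diagonal. A minimal sketch of the mask that the emitted arange/unsqueeze/ge ops compute (names below are illustrative, not from the patch):
```
import torch

rows, cols, diagonal = 4, 5, 1
row_indices = torch.arange(rows).unsqueeze(1)  # (4, 1), cf. AtenUnsqueezeOp(rowArange, cstOne)
col_indices = torch.arange(cols).unsqueeze(0)  # (1, 5), cf. AtenUnsqueezeOp(colArange, cstZero)
mask = col_indices >= row_indices + diagonal   # (4, 5), cf. AtenGeTensorOp
# triu(diagonal=1) keeps exactly the positions strictly above the main diagonal.
assert torch.equal(mask, torch.ones(rows, cols).triu(diagonal).bool())
```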

@@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTypeAsOp>();
target.addIllegalOp<AtenTileOp>();
target.addIllegalOp<AtenReshapeAsOp>();
target.addIllegalOp<AtenTriuOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
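
With `AtenTriuOp` marked illegal, lowering to the torch backend contract runs the decomposition. One way to inspect the result (a sketch, assuming the `torch_mlir.compile` entry point available at the time of this commit):
```
import torch
import torch_mlir

class Triu(torch.nn.Module):
    def forward(self, x):
        return torch.triu(x, 1)

# With output_type="torch" the decomposition pipeline runs, so the printed
# IR should contain torch.aten.arange / torch.aten.ge.Tensor /
# torch.aten.where.ScalarOther rather than torch.aten.triu.
module = torch_mlir.compile(Triu(), torch.rand(4, 5), output_type="torch")
print(module)
```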

@@ -3251,6 +3251,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================
class TriuModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.triu(x, 1)


@register_test_case(module_factory=lambda: TriuModule())
def TriuModule_basic(module, tu: TestUtils):
    x = torch.tensor([[ 0.5876, -0.0794, -1.8373,  0.6654, 0.2],
                      [-0.2447,  0.9556, -1.2919,  1.3378, 0.3],
                      [ 0.4333,  0.3146,  0.6576, -1.0432, 0.4],
                      [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]])
    module.forward(x)
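
The NaN/inf row in this test pins down why the decomposition selects with `aten.where` instead of multiplying by a 0/1 mask: multiplication would leak NaNs (`0 * nan == nan`, `0 * inf == nan`), while `where` replaces masked entries outright. A small illustration:
```
import torch

x = torch.tensor([[1.0, torch.nan], [torch.inf, 2.0]])
mask = torch.ones(2, 2).triu().bool()
print(x * mask)                  # tensor([[1., nan], [nan, 2.]]) -- inf * 0 leaks a nan
print(torch.where(mask, x, 0.))  # tensor([[1., nan], [0., 2.]]) -- masked entries zeroed
```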
# ==============================================================================
class TriuBroadcastModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.triu(x, 2)


@register_test_case(module_factory=lambda: TriuBroadcastModule())
def TriuBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6))
# ==============================================================================
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
    def __init__(self):