mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Decompose AtenTriuOp (#2561)
decompose like: ``` import torch def my_triu(x, diag): rows = torch.ops.aten.size(x, -2) cols = torch.ops.aten.size(x, -1) row_indices = torch.ops.aten.arange(rows).unsqueeze(1) col_indices = torch.ops.aten.arange(cols).unsqueeze(0) cond = torch.ops.aten.ge( col_indices, torch.ops.aten.add(row_indices, diag)) return torch.ops.aten.where(cond, x, 0) x = torch.rand(5, 7) assert torch.allclose(my_triu(x, 0), torch.triu(x, 0)) assert torch.allclose(my_triu(x, 1), torch.triu(x, 1)) assert torch.allclose(my_triu(x, 2), torch.triu(x, 2)) assert torch.allclose(my_triu(x, -1), torch.triu(x, -1)) ``` --------- Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>pull/2602/head
parent
49fdc1a8a6
commit
f7a92d346e
|
@ -246,6 +246,62 @@ public:
|
|||
};
|
||||
} // end namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTriuOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
MLIRContext *context = op.getContext();
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
||||
}
|
||||
if (inputType.getSizes().size() < 2) {
|
||||
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
|
||||
}
|
||||
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
Value cstZero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-2));
|
||||
Value colDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
|
||||
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
|
||||
|
||||
Value rowArange = rewriter.create<AtenArangeOp>(
|
||||
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
Value colArange = rewriter.create<AtenArangeOp>(
|
||||
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
|
||||
Value unsqueezeRowArange =
|
||||
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
|
||||
Value unsqueezeColArange =
|
||||
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
|
||||
|
||||
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
|
||||
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
|
||||
|
||||
Value condTensor = rewriter.create<AtenGeTensorOp>(
|
||||
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
|
||||
op, op.getResult().getType(), condTensor, input, cstZero);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
||||
public:
|
||||
|
@ -5817,6 +5873,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
|
@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenTypeAsOp>();
|
||||
target.addIllegalOp<AtenTileOp>();
|
||||
target.addIllegalOp<AtenReshapeAsOp>();
|
||||
target.addIllegalOp<AtenTriuOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||
|
|
|
@ -3251,6 +3251,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class TriuModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4,5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.triu(x, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TriuModule())
|
||||
def TriuModule_basic(module, tu: TestUtils):
|
||||
x=torch.tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2],
|
||||
[-0.2447, 0.9556, -1.2919, 1.3378, 0.3],
|
||||
[ 0.4333, 0.3146, 0.6576, -1.0432, 0.4],
|
||||
[-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]])
|
||||
module.forward(x)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TriuBroadcastModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3,4,5,6], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.triu(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TriuBroadcastModule())
|
||||
def TriuBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3,4,5,6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue