mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Decompose AtenTriuOp (#2561)
decompose like: ``` import torch def my_triu(x, diag): rows = torch.ops.aten.size(x, -2) cols = torch.ops.aten.size(x, -1) row_indices = torch.ops.aten.arange(rows).unsqueeze(1) col_indices = torch.ops.aten.arange(cols).unsqueeze(0) cond = torch.ops.aten.ge( col_indices, torch.ops.aten.add(row_indices, diag)) return torch.ops.aten.where(cond, x, 0) x = torch.rand(5, 7) assert torch.allclose(my_triu(x, 0), torch.triu(x, 0)) assert torch.allclose(my_triu(x, 1), torch.triu(x, 1)) assert torch.allclose(my_triu(x, 2), torch.triu(x, 2)) assert torch.allclose(my_triu(x, -1), torch.triu(x, -1)) ``` --------- Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>pull/2602/head
parent
49fdc1a8a6
commit
f7a92d346e
|
@ -246,6 +246,62 @@ public:
|
||||||
};
|
};
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTriuOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
MLIRContext *context = op.getContext();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value input = op.getSelf();
|
||||||
|
auto inputType = input.getType().cast<BaseTensorType>();
|
||||||
|
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
||||||
|
}
|
||||||
|
if (inputType.getSizes().size() < 2) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||||
|
Value cstZero =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value cstOne =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
|
||||||
|
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(-2));
|
||||||
|
Value colDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(-1));
|
||||||
|
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
|
||||||
|
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
|
||||||
|
|
||||||
|
Value rowArange = rewriter.create<AtenArangeOp>(
|
||||||
|
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
|
||||||
|
/*device=*/none, /*pin_memory=*/none);
|
||||||
|
Value colArange = rewriter.create<AtenArangeOp>(
|
||||||
|
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
|
||||||
|
/*device=*/none, /*pin_memory=*/none);
|
||||||
|
|
||||||
|
Value unsqueezeRowArange =
|
||||||
|
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
|
||||||
|
Value unsqueezeColArange =
|
||||||
|
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
|
||||||
|
|
||||||
|
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
|
||||||
|
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
|
||||||
|
|
||||||
|
Value condTensor = rewriter.create<AtenGeTensorOp>(
|
||||||
|
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
|
||||||
|
op, op.getResult().getType(), condTensor, input, cstZero);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -5817,6 +5873,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||||
|
|
||||||
GreedyRewriteConfig config;
|
GreedyRewriteConfig config;
|
||||||
config.useTopDownTraversal = true;
|
config.useTopDownTraversal = true;
|
||||||
|
|
|
@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenTypeAsOp>();
|
target.addIllegalOp<AtenTypeAsOp>();
|
||||||
target.addIllegalOp<AtenTileOp>();
|
target.addIllegalOp<AtenTileOp>();
|
||||||
target.addIllegalOp<AtenReshapeAsOp>();
|
target.addIllegalOp<AtenReshapeAsOp>();
|
||||||
|
target.addIllegalOp<AtenTriuOp>();
|
||||||
for (auto &opName : backendLegalOpsSet) {
|
for (auto &opName : backendLegalOpsSet) {
|
||||||
target.addLegalOp(
|
target.addLegalOp(
|
||||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||||
|
|
|
@ -3251,6 +3251,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TriuModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([4,5], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.triu(x, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TriuModule())
|
||||||
|
def TriuModule_basic(module, tu: TestUtils):
|
||||||
|
x=torch.tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2],
|
||||||
|
[-0.2447, 0.9556, -1.2919, 1.3378, 0.3],
|
||||||
|
[ 0.4333, 0.3146, 0.6576, -1.0432, 0.4],
|
||||||
|
[-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]])
|
||||||
|
module.forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TriuBroadcastModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3,4,5,6], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.triu(x, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TriuBroadcastModule())
|
||||||
|
def TriuBroadcastModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3,4,5,6))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue