linalg: lower `aten.triu` op to `linalg.generic` (#965)

Prior to this patch, the torch dialect included `AtenTriuOp` for
computing the upper triangular part of the input matrix, but there was
no code for lowering the op to the linalg dialect.

This patch adds code to generate a `linalg.generic` operation that
compares indices (computed using `linalg.index`) to choose between zero
and the original value (using `arith.select`).  The lowering fails if the
input tensor has fewer than two dimensions.  This patch also adds a few
end-to-end tests.
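
As an illustration (not part of the patch), the element-wise rule that the
lowering encodes can be sketched in plain PyTorch; `reference_triu` is a
hypothetical helper name used only for this example:

import torch

def reference_triu(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Keep element (i, j) of the two innermost dimensions only when
    # j >= i + diagonal; all other elements become zero. This is the
    # same predicate the generated linalg.generic payload evaluates.
    rows = torch.arange(x.shape[-2]).unsqueeze(-1)
    cols = torch.arange(x.shape[-1])
    keep = cols >= rows + diagonal
    return torch.where(keep, x, torch.zeros((), dtype=x.dtype))

x = torch.rand(5, 8, 3, 4, 3)
assert torch.equal(reference_triu(x, diagonal=2), torch.triu(x, diagonal=2))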
Ashay Rane 2022-06-23 22:45:48 -07:00 committed by GitHub
parent 143a7bcb76
commit 234fc7fe0c
6 changed files with 113 additions and 3 deletions


@@ -857,6 +857,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    return b.create<arith::SelectOp>(loc, mask, fillValue, input);
  }
  if (auto triu = dyn_cast<AtenTriuOp>(op)) {
    // Check if the rank of the input tensor is valid.
    AtenTriuOp::Adaptor adaptor(operands);
    auto inputType = adaptor.self().getType().cast<RankedTensorType>();
    uint64_t inputRank = inputType.getRank();
    if (inputRank < 2) {
      triu.emitError("too few dimensions to compute triangular part of matrix");
      return nullptr;
    }

    // Use the indices of the two innermost dimensions.
    auto rowIndex = b.create<linalg::IndexOp>(loc, inputRank - 2);
    Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex);
    auto colIndex = b.create<linalg::IndexOp>(loc, inputRank - 1);
    Value colIndexI64 = castIndexToInt64(b, loc, colIndex);

    // columnIndex >= rowIndex + diagonal?
    auto sum = b.create<arith::AddIOp>(loc, rowIndexI64, adaptor.diagonal());
    auto pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                        colIndexI64, sum);

    Value scalar = payloadArgs[0];
    Type elementType = inputType.getElementType();
    Value zero = getConstant(b, loc, 0, elementType);
    return b.create<arith::SelectOp>(loc, pred, scalar, zero);
  }

  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;
@@ -902,7 +929,7 @@ public:
          AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
          AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
          AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-         AtenLogicalOrOp>(op))
+         AtenLogicalOrOp, AtenTriuOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1640,7 +1667,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
-     AtenLogicalOrOp>();
+     AtenLogicalOrOp, AtenTriuOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenDetachOp>(typeConverter, context);


@@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
          AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
          AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
          AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-         PrimAbsScalarOp, AtenNumpyTOp>(op)) {
+         PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
    return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
  }


@@ -5297,6 +5297,10 @@ module {
    %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>
    return %3 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.triu"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.tanh"(%arg0: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>


@@ -286,6 +286,9 @@ def not_present_in_registry(f):
# Shape functions
# ==============================================================================

def aten〇triu(self: List[int], diagonal: int = 0) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇tanh(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
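
As a quick sanity check on the shape function (plain PyTorch, for context only):
`triu` only zeroes elements and never changes the shape of its input, so
forwarding to the generic unary shape helper is sufficient.

import torch

x = torch.rand(3, 4, 5)
# The result has exactly the same shape as the input, for any diagonal.
assert torch.triu(x, diagonal=-1).shape == x.shape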


@@ -1806,3 +1806,69 @@ class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module):
    module_factory=lambda: ElementwiseAtenFloorDivideBroadcastModule())
def ElementwiseAtenFloorDivideBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3), tu.rand(4, 3))

# ==============================================================================

class AtenTriuModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.triu(x)


@register_test_case(module_factory=lambda: AtenTriuModule())
def AtenTriuModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 8, 3, 4, 3))

# ==============================================================================

class AtenTriuWithPosDiagonalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.triu(x, diagonal=2)


@register_test_case(module_factory=lambda: AtenTriuWithPosDiagonalModule())
def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(9, 4, 3))

# ==============================================================================

class AtenTriuWithNegDiagonalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.triu(x, diagonal=-4)


@register_test_case(module_factory=lambda: AtenTriuWithNegDiagonalModule())
def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 1, 5, 9))


@@ -84,3 +84,13 @@ func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vt
  %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
  return %1 : !torch.vtensor<[?],f32>
}

// -----

func.func @insufficient_dims_for_triu(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
  %int0 = torch.constant.int 0
  // expected-error@+2 {{failed to legalize operation 'torch.aten.triu' that was explicitly marked illegal}}
  // expected-error@+1 {{too few dimensions to compute triangular part of matrix}}
  %0 = torch.aten.triu %arg0, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
  return %0 : !torch.vtensor<[?],f32>
}
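
For comparison, eager PyTorch also rejects inputs with fewer than two
dimensions, which is the same condition the diagnostic above guards against
(illustrative snippet; the exact eager error text may differ):

import torch

try:
    torch.triu(torch.rand(4))  # rank-1 input
except RuntimeError as e:
    # PyTorch reports that triu requires at least a 2-D tensor.
    print(e)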