mirror of https://github.com/llvm/torch-mlir
linalg: lower `aten.triu` op to `linalg.generic` (#965)
Prior to this patch, the torch dialect included `AtenTriuOp` for computing the upper triangular part of the input matrix, but there was no code for lowering the op to the linalg dialect. This patch adds code to generate a `linalg.generic` operation that compares indices (computed using `linalg.index`) to choose between zero or the original value (using `arith.select`). The lowering fails if the number of dimensions are less than two. This patch also adds a few end-to-end tests.pull/951/head snapshot-20220624.513
parent
143a7bcb76
commit
234fc7fe0c
|
@ -857,6 +857,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
|
||||
}
|
||||
|
||||
if (auto triu = dyn_cast<AtenTriuOp>(op)) {
|
||||
// Check if the rank of the input tensor is valid.
|
||||
AtenTriuOp::Adaptor adaptor(operands);
|
||||
auto inputType = adaptor.self().getType().cast<RankedTensorType>();
|
||||
uint64_t inputRank = inputType.getRank();
|
||||
if (inputRank < 2) {
|
||||
triu.emitError("too few dimensions to compute triangular part of matrix");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Use the indices of the two innermost dimensions.
|
||||
auto rowIndex = b.create<linalg::IndexOp>(loc, inputRank - 2);
|
||||
Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex);
|
||||
auto colIndex = b.create<linalg::IndexOp>(loc, inputRank - 1);
|
||||
Value colIndexI64 = castIndexToInt64(b, loc, colIndex);
|
||||
|
||||
// columnIndex >= rowIndex + diagonal?
|
||||
auto sum = b.create<arith::AddIOp>(loc, rowIndexI64, adaptor.diagonal());
|
||||
auto pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
colIndexI64, sum);
|
||||
|
||||
Value scalar = payloadArgs[0];
|
||||
Type elementType = inputType.getElementType();
|
||||
Value zero = getConstant(b, loc, 0, elementType);
|
||||
return b.create<arith::SelectOp>(loc, pred, scalar, zero);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
return nullptr;
|
||||
|
@ -902,7 +929,7 @@ public:
|
|||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp>(op))
|
||||
AtenLogicalOrOp, AtenTriuOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1640,7 +1667,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp>();
|
||||
AtenLogicalOrOp, AtenTriuOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp, AtenNumpyTOp>(op)) {
|
||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
|
|
|
@ -5297,6 +5297,10 @@ module {
|
|||
%3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
return %3 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.triu"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.tanh"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -286,6 +286,9 @@ def not_present_in_registry(f):
|
|||
# Shape functions
|
||||
# ==============================================================================
|
||||
|
||||
def aten〇triu(self: List[int], diagonal: int = 0) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇tanh(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
|
|
@ -1806,3 +1806,69 @@ class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module):
|
|||
module_factory=lambda: ElementwiseAtenFloorDivideBroadcastModule())
|
||||
def ElementwiseAtenFloorDivideBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTriuModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTriuModule())
|
||||
def AtenTriuModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 8, 3, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTriuWithPosDiagonalModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTriuWithPosDiagonalModule())
|
||||
def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(9, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=-4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTriuWithNegDiagonalModule())
|
||||
def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 5, 9))
|
||||
|
|
|
@ -84,3 +84,13 @@ func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vt
|
|||
%1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
|
||||
return %1 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @insufficient_dims_for_triu(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
// expected-error@+2 {{failed to legalize operation 'torch.aten.triu' that was explicitly marked illegal}}
|
||||
// expected-error@+1 {{too few dimensions to compute triangular part of matrix}}
|
||||
%0 = torch.aten.triu %arg0, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||
return %0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue