[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632)

pull/2788/head
Franz Haniel 2024-01-22 18:47:13 +01:00 committed by GitHub
parent 50ac3b1912
commit b9806cfa38
8 changed files with 396 additions and 1 deletion
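
For context, and not part of the diff itself: aten::diagonal extracts the diagonal running between dim1 and dim2, keeps the remaining dimensions, and appends the diagonal as the last dimension of the result; a positive offset selects a diagonal above the main one, a negative offset one below it. A minimal Python sketch of these semantics follows; the diag_len helper is illustrative only and mirrors the diagSize computation in the lowering below.

    import torch

    a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

    # Diagonal over dims (1, 2); dim 0 is kept, the diagonal becomes the last dim.
    print(torch.diagonal(a, offset=0, dim1=1, dim2=2).shape)   # torch.Size([2, 3])
    print(torch.diagonal(a, offset=1, dim1=1, dim2=2).shape)   # torch.Size([2, 3])
    print(torch.diagonal(a, offset=-1, dim1=1, dim2=2).shape)  # torch.Size([2, 2])
    print(torch.diagonal(a, offset=5, dim1=1, dim2=2).shape)   # torch.Size([2, 0]), empty

    # Illustrative helper (not from this patch) mirroring the diagSize computation.
    def diag_len(n1: int, n2: int, offset: int) -> int:
        if offset < 0:
            return max(min(n1 + offset, n2), 0)
        return max(min(n1, n2 - offset), 0)

    assert torch.diagonal(a, offset=1, dim1=1, dim2=2).shape[-1] == diag_len(3, 4, 1)
    assert torch.diagonal(a, offset=-1, dim1=1, dim2=2).shape[-1] == diag_len(3, 4, -1)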

View File

@@ -11306,6 +11306,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
}];
}
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::diagonal : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$offset,
Torch_IntType:$dim1,
Torch_IntType:$dim2
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiagonalOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenDiagonalOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@@ -1834,6 +1834,142 @@ public:
};
} // namespace
namespace {
class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
int64_t offset;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
return rewriter.notifyMatchFailure(op, "offset must be constant");
int64_t dim1;
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
int64_t dim2;
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
return rewriter.notifyMatchFailure(op, "dim2 must be constant");
Value inputMatrix = adaptor.getSelf();
RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
if (inputRank < 2)
return rewriter.notifyMatchFailure(
op, "input must have at least two dimensions");
int64_t outputRank = inputRank - 1;
dim1 = toPositiveDim(dim1, inputRank);
if (!isValidDim(dim1, inputRank))
return rewriter.notifyMatchFailure(op, "dim1 out of range");
dim2 = toPositiveDim(dim2, inputRank);
if (!isValidDim(dim2, inputRank))
return rewriter.notifyMatchFailure(op, "dim2 out of range");
if (dim1 == dim2)
return rewriter.notifyMatchFailure(
op, "diagonal dimensions cannot be identical");
Type elementType = inputType.getElementType();
RankedTensorType outputType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Location loc = op.getLoc();
Value dim1Size, dim2Size;
dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1);
dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2);
// Compute the length of the diagonal, taking the offset into account. If the
// offset is very large in magnitude, diagSize is 0 and an empty tensor is
// returned.
Value indexZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value indexMinusOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
Value indexOffset = rewriter.create<arith::ConstantIndexOp>(loc, offset);
Value offsetIsNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, indexOffset, indexZero);
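// Note: a zero offset takes this "negative" branch (sle); both size formulas
// below reduce to min(dim1Size, dim2Size) in that case, so the result agrees.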
Value sizeForNegativeOffset = rewriter.create<arith::MaxSIOp>(
loc,
rewriter.create<arith::MinSIOp>(
loc, rewriter.create<arith::AddIOp>(loc, dim1Size, indexOffset),
dim2Size),
indexZero);
Value sizeForPositiveOffset = rewriter.create<arith::MaxSIOp>(
loc,
rewriter.create<arith::MinSIOp>(
loc, rewriter.create<arith::SubIOp>(loc, dim2Size, indexOffset),
dim1Size),
indexZero);
Value diagSize = rewriter.create<arith::SelectOp>(
loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset);
// Depending on its sign, the offset shifts the start index of the diagonal
// along dim1 (negative offset) or dim2 (positive offset).
Value diagStart1 = rewriter.create<arith::SelectOp>(
loc, offsetIsNegative,
rewriter.create<arith::MulIOp>(loc, indexOffset, indexMinusOne),
indexZero);
Value diagStart2 = rewriter.create<arith::SelectOp>(loc, offsetIsNegative,
indexZero, indexOffset);
SmallVector<Value> outputDims;
for (auto i = 0; i < inputRank; i++) {
if (!(i == dim1 || i == dim2))
outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i));
}
outputDims.push_back(diagSize);
Value outputMatrix = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outputDims), elementType);
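// Gather the diagonal with a linalg.generic over the output shape: the
// identity map drives the output indices, the last index walks along the
// diagonal, and each element is read from the input with tensor.extract at
// the corresponding dim1/dim2 positions.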
SmallVector<AffineMap> indexingMaps = {
AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())};
SmallVector<utils::IteratorType> iteratorTypes(
outputRank, utils::IteratorType::parallel);
auto diagonal =
rewriter
.create<linalg::GenericOp>(
loc, outputMatrix.getType(), ValueRange{}, outputMatrix,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> diagIndices;
Value indexOnDiag =
b.create<linalg::IndexOp>(loc, outputRank - 1);
Value dim1Index =
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart1);
Value dim2Index =
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart2);
// specify at which input indices the diagonal values are
// extracted
for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) {
if (indIn == dim1)
diagIndices.push_back(dim1Index);
else if (indIn == dim2)
diagIndices.push_back(dim2Index);
else {
diagIndices.push_back(
b.create<linalg::IndexOp>(loc, indOut));
indOut++;
}
}
Value diagElt = b.create<tensor::ExtractOp>(
loc, elementType, inputMatrix, diagIndices);
b.create<linalg::YieldOp>(loc, diagElt);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, diagonal);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -1872,4 +2008,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsRealOp>();
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
target.addIllegalOp<AtenDiagonalOp>();
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
}

View File

@@ -6238,6 +6238,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: input must have at least two dimensions\"\n"
" %int2 = torch.constant.int 2\n"
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg2, %2, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %5 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg3, %4, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %6 = torch.aten.ne.int %3, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %9 = torch.prim.ListConstruct %int9223372036854775807, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %10, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %19 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.eq.int %arg4, %3 : !torch.int, !torch.int -> !torch.bool\n"
" %21 = torch.prim.If %20 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %22 = torch.aten.eq.int %arg4, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %22 : !torch.bool\n"
" }\n"
" torch.prim.If %21 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %22 = torch.aten.append.t %7, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %11 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.sub.int %12, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.prim.min.int %11, %13 : !torch.int, !torch.int -> !torch.int\n"
" %15 = torch.prim.max.int %14, %int0 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
" %19 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.add.int %19, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %22 = torch.prim.min.int %20, %21 : !torch.int, !torch.int -> !torch.int\n"
" %23 = torch.prim.max.int %22, %int0 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %23 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %15 : !torch.int\n"
" }\n"
" %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" return %7 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9980,6 +10048,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.diagonal\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@@ -222,7 +222,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
AtenPixelShuffleOp>(op);
AtenPixelShuffleOp, AtenDiagonalOp>(op);
}
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

View File

@@ -59,6 +59,36 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
    return upstream_shape_functions.unary(self)

@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
    Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
    Invocation(TensorOfShape(2, 3, 4), dim1=-1, dim2=-2, offset=1), # Positive `offset`.
    Invocation(TensorOfShape(2, 3, 4), offset=-1), # Negative `offset`.
    Invocation(TensorOfShape(2, 3, 4), offset=3), # Empty result due to large `offset`.
    ErrorInvocation(TensorOfShape(2)), # Input one-dimensional.
    ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal.
    ErrorInvocation(TensorOfShape(2, 3, 4), dim1=3, dim2=1), # `dim1` out of bounds.
])
def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> List[int]:
    assert len(self) >= 2, "input must have at least two dimensions"
    dim1 = upstream_shape_functions.maybe_wrap_dim(dim1, len(self))
    dim2 = upstream_shape_functions.maybe_wrap_dim(dim2, len(self))
    assert dim1 != dim2, "diagonal dimensions cannot be identical"

    diagonal: List[int] = []
    for i, self_dim in enumerate(self):
        if (i==dim1) or (i==dim2):
            pass
        else:
            diagonal.append(self_dim)

    diag_size = max(min(self[dim1], self[dim2] - offset), 0)
    if offset<0:
        diag_size = max(min(self[dim1] + offset, self[dim2]), 0)
    diagonal.append(diag_size)
    return diagonal

def aten〇tan〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
@@ -2493,6 +2523,11 @@ def aten〇tril〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) ->
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim1=0, dim2=1))
def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int:
    self_rank, self_dtype = self_rank_dtype

View File

@@ -672,6 +672,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::alias_copy : (Tensor) -> (Tensor)")
    emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
    emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
    emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
    emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
    emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)")

View File

@@ -60,3 +60,4 @@ def register_all_tests():
    from . import control_flow
    from . import stats
    from . import padding
    from . import diagonal

View File

@@ -0,0 +1,123 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class DiagonalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a)


@register_test_case(module_factory=lambda: DiagonalModule())
def DiagonalModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3))


@register_test_case(module_factory=lambda: DiagonalModule())
def DiagonalModule_nonsquare(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
# ==============================================================================
class DiagonalTransposedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a, dim1=1, dim2=0)


@register_test_case(module_factory=lambda: DiagonalTransposedModule())
def DiagonalModule_transposed(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
# ==============================================================================
class DiagonalWithDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a, dim1=0, dim2=1)


@register_test_case(module_factory=lambda: DiagonalWithDimsModule())
def DiagonalModule_with_dims(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class DiagonalWithNegativeDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1)


@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule())
def DiagonalModule_with_negative_dims(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class DiagonalWithOffsetModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a, offset=1)


@register_test_case(module_factory=lambda: DiagonalWithOffsetModule())
def DiagonalModule_with_offset(module, tu: TestUtils):
    module.forward(tu.rand(4, 6))
# ==============================================================================
class DiagonalWithDimsOffsetModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1)


@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))