mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632)
parent
50ac3b1912
commit
b9806cfa38
|
@ -11306,6 +11306,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::diagonal : (Tensor, int, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$offset,
|
||||
Torch_IntType:$dim1,
|
||||
Torch_IntType:$dim2
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDiagonalOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenDiagonalOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1834,6 +1834,142 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
int64_t offset;
|
||||
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
|
||||
return rewriter.notifyMatchFailure(op, "offset must be constant");
|
||||
int64_t dim1;
|
||||
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
|
||||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||
int64_t dim2;
|
||||
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
|
||||
return rewriter.notifyMatchFailure(op, "dim2 must be constant");
|
||||
|
||||
Value inputMatrix = adaptor.getSelf();
|
||||
RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
if (inputRank < 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input must have at least two dimensions");
|
||||
int64_t outputRank = inputRank - 1;
|
||||
|
||||
dim1 = toPositiveDim(dim1, inputRank);
|
||||
if (!isValidDim(dim1, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
||||
dim2 = toPositiveDim(dim2, inputRank);
|
||||
if (!isValidDim(dim2, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim2 out of range");
|
||||
if (dim1 == dim2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "diagonal dimensions cannot be identical");
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
RankedTensorType outputType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value dim1Size, dim2Size;
|
||||
dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1);
|
||||
dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2);
|
||||
|
||||
// compute the length of the diagonal with possible offset
|
||||
// if the offset is very large or very small, diagSize=0 and an empty tensor
|
||||
// is returned
|
||||
Value indexZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value indexMinusOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
|
||||
Value indexOffset = rewriter.create<arith::ConstantIndexOp>(loc, offset);
|
||||
Value offsetIsNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sle, indexOffset, indexZero);
|
||||
Value sizeForNegativeOffset = rewriter.create<arith::MaxSIOp>(
|
||||
loc,
|
||||
rewriter.create<arith::MinSIOp>(
|
||||
loc, rewriter.create<arith::AddIOp>(loc, dim1Size, indexOffset),
|
||||
dim2Size),
|
||||
indexZero);
|
||||
Value sizeForPositiveOffset = rewriter.create<arith::MaxSIOp>(
|
||||
loc,
|
||||
rewriter.create<arith::MinSIOp>(
|
||||
loc, rewriter.create<arith::SubIOp>(loc, dim2Size, indexOffset),
|
||||
dim1Size),
|
||||
indexZero);
|
||||
Value diagSize = rewriter.create<arith::SelectOp>(
|
||||
loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset);
|
||||
|
||||
// depending on its sign, the offset affects only the row or column indices
|
||||
// of the diagonal
|
||||
Value diagStart1 = rewriter.create<arith::SelectOp>(
|
||||
loc, offsetIsNegative,
|
||||
rewriter.create<arith::MulIOp>(loc, indexOffset, indexMinusOne),
|
||||
indexZero);
|
||||
Value diagStart2 = rewriter.create<arith::SelectOp>(loc, offsetIsNegative,
|
||||
indexZero, indexOffset);
|
||||
|
||||
SmallVector<Value> outputDims;
|
||||
for (auto i = 0; i < inputRank; i++) {
|
||||
if (!(i == dim1 || i == dim2))
|
||||
outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i));
|
||||
}
|
||||
outputDims.push_back(diagSize);
|
||||
|
||||
Value outputMatrix = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outputDims), elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
outputRank, utils::IteratorType::parallel);
|
||||
|
||||
auto diagonal =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputMatrix.getType(), ValueRange{}, outputMatrix,
|
||||
indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
SmallVector<Value> diagIndices;
|
||||
Value indexOnDiag =
|
||||
b.create<linalg::IndexOp>(loc, outputRank - 1);
|
||||
Value dim1Index =
|
||||
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart1);
|
||||
Value dim2Index =
|
||||
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart2);
|
||||
|
||||
// specify at which input indices the diagonal values are
|
||||
// extracted
|
||||
for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) {
|
||||
if (indIn == dim1)
|
||||
diagIndices.push_back(dim1Index);
|
||||
else if (indIn == dim2)
|
||||
diagIndices.push_back(dim2Index);
|
||||
else {
|
||||
diagIndices.push_back(
|
||||
b.create<linalg::IndexOp>(loc, indOut));
|
||||
indOut++;
|
||||
}
|
||||
}
|
||||
Value diagElt = b.create<tensor::ExtractOp>(
|
||||
loc, elementType, inputMatrix, diagIndices);
|
||||
b.create<linalg::YieldOp>(loc, diagElt);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, diagonal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -1872,4 +2008,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenViewAsRealOp>();
|
||||
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenDiagonalOp>();
|
||||
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -6238,6 +6238,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: input must have at least two dimensions\"\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg2, %2, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
|
||||
" %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %5 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg3, %4, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
|
||||
" %6 = torch.aten.ne.int %3, %5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %9 = torch.prim.ListConstruct %int9223372036854775807, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %10, %true, init() {\n"
|
||||
" ^bb0(%arg4: !torch.int):\n"
|
||||
" %19 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.eq.int %arg4, %3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %21 = torch.prim.If %20 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %22 = torch.aten.eq.int %arg4, %5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %22 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %21 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %22 = torch.aten.append.t %7, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.sub.int %12, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.prim.min.int %11, %13 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.prim.max.int %14, %int0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
|
||||
" %19 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.add.int %19, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %21 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %22 = torch.prim.min.int %20, %21 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %23 = torch.prim.max.int %22, %int0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %23 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %15 : !torch.int\n"
|
||||
" }\n"
|
||||
" %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" return %7 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -9980,6 +10048,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.diagonal\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -222,7 +222,7 @@ bool Torch::isViewLikeOp(Operation *op) {
|
|||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
|
||||
AtenPixelShuffleOp>(op);
|
||||
AtenPixelShuffleOp, AtenDiagonalOp>(op);
|
||||
}
|
||||
|
||||
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||
|
|
|
@ -59,6 +59,36 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
|||
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
|
||||
Invocation(TensorOfShape(2, 3, 4), dim1=-1, dim2=-2, offset=1), # Positive `offset`.
|
||||
Invocation(TensorOfShape(2, 3, 4), offset=-1), # Negative `offset``.
|
||||
Invocation(TensorOfShape(2, 3, 4), offset=3), # Empty result due to large `offset`.
|
||||
ErrorInvocation(TensorOfShape(2)), # Input one-dimensional.
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal.
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=3, dim2=1), # `dim1` out of bounds.
|
||||
])
|
||||
def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> List[int]:
|
||||
assert len(self) >= 2, "input must have at least two dimensions"
|
||||
dim1 = upstream_shape_functions.maybe_wrap_dim(dim1, len(self))
|
||||
dim2 = upstream_shape_functions.maybe_wrap_dim(dim2, len(self))
|
||||
assert dim1 != dim2, "diagonal dimensions cannot be identical"
|
||||
|
||||
diagonal: List[int] = []
|
||||
for i, self_dim in enumerate(self):
|
||||
if (i==dim1) or (i==dim2):
|
||||
pass
|
||||
else:
|
||||
diagonal.append(self_dim)
|
||||
|
||||
diag_size = max(min(self[dim1], self[dim2] - offset), 0)
|
||||
if offset<0:
|
||||
diag_size = max(min(self[dim1] + offset, self[dim2]), 0)
|
||||
diagonal.append(diag_size)
|
||||
|
||||
return diagonal
|
||||
|
||||
def aten〇tan〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2493,6 +2523,11 @@ def aten〇tril〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) ->
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim1=0, dim2=1))
|
||||
def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -672,6 +672,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
|
||||
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)")
|
||||
|
|
|
@ -60,3 +60,4 @@ def register_all_tests():
|
|||
from . import control_flow
|
||||
from . import stats
|
||||
from . import padding
|
||||
from . import diagonal
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalModule())
|
||||
def DiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalModule())
|
||||
def DiagonalModule_nonsquare(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalTransposedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=1, dim2=0)
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalTransposedModule())
|
||||
def DiagonalModule_transposed(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalWithDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1)
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithDimsModule())
|
||||
def DiagonalModule_with_dims(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalWithNegativeDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1)
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule())
|
||||
def DiagonalModule_with_negative_dims(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalWithOffsetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, offset=1)
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithOffsetModule())
|
||||
def DiagonalModule_with_offset(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DiagonalWithDimsOffsetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1)
|
||||
|
||||
@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
|
||||
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
Loading…
Reference in New Issue