From 1fcbfa87ec9ab2810bfaacd6c534f3ee0b053abd Mon Sep 17 00:00:00 2001 From: schnkmwt <152340442+schnkmwt@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:32:50 -0700 Subject: [PATCH] Implement linalg lowering of diag_embed torch op (#2885) This PR adds lowering of diag_embed to linalg dilect. Tracked in https://github.com/nod-ai/SHARK-Turbine/issues/288 --------- Co-authored-by: sachink --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 156 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 89 ++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 6 + .../build_tools/abstract_interp_lib_gen.py | 45 +++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/constant_alloc.py | 115 +++++++++++++ 7 files changed, 438 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c683723d9..808f329a0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8429,6 +8429,32 @@ def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [ }]; } +def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::diag_embed : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDiagEmbedOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenDiagEmbedOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index e4bf1886b..d2953ec22 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" @@ -2094,6 +2095,159 @@ public: }; } // namespace +namespace { +class ConvertAtenDiagEmbedOp : public OpConversionPattern { + + static SmallVector + getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, + int64_t offset, int64_t dim1, int64_t dim2) { + auto inputType = tensor.getType().cast(); + auto inputRank = inputType.getRank(); + + // output tensor always has 1 extra dimension + auto resultRank = inputRank + 1; + + // regardless of offset sign, output tensor is same + Value constOffset = b.create(loc, offset); + Value absOffset = b.create(loc, constOffset); + + // diagonal size is determined by last input dimension + auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); + Value diagDim = b.create(loc, lastInputDim, absOffset); + + // output shape has same dimensions as input + // except for the diagonal dimensions + int input_dim_idx = 0; + SmallVector resultShape; + for (unsigned int i = 0; i < resultRank; i++) { + if (i == dim1 || i == dim2) + resultShape.push_back(diagDim); + else + resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++)); + } + + return resultShape; + } + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + + Value input = adaptor.getSelf(); + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + auto resultRank = inputRank + 1; + + int64_t offset; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + return rewriter.notifyMatchFailure(op, "offset is not constant"); + + int64_t dim1; + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return rewriter.notifyMatchFailure(op, "dim1 is not constant"); + dim1 = toPositiveDim(dim1, resultRank); + if (!isValidDim(dim1, resultRank)) + return rewriter.notifyMatchFailure( + op, "dim1 can only be in closed range [" + + std::to_string(-resultRank) + "," + + std::to_string(resultRank - 1) + "]"); + + int64_t dim2; + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) + return rewriter.notifyMatchFailure(op, "dim2 is not constant"); + dim2 = toPositiveDim(dim2, resultRank); + if (!isValidDim(dim2, resultRank)) + return rewriter.notifyMatchFailure( + op, "dim2 can only be in closed range [" + + std::to_string(-resultRank) + "," + + std::to_string(resultRank - 1) + "]"); + + if (dim1 == dim2) + return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal"); + + // add linalg.fill + Type resultElemType = inputType.getElementType(); + auto resultShape = + getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); + Value zeroTensor = + createZeroInitTensor(rewriter, loc, resultShape, resultElemType); + + // add linalg.generic with diagonal access pattern affine indexing maps + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(resultRank), + }; + SmallVector iteratorTypes( + resultRank, utils::IteratorType::parallel); + Value resultTensor = + rewriter + .create( + loc, zeroTensor.getType(), ValueRange{}, zeroTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dim1Index = b.create(loc, dim1); + Value dim2Index = b.create(loc, dim2); + + // to pick right element from input, first add all dimensions + // except last one, then last will be either dim1 or dim2 + // depending upon lower or upper diagonal defined by offset + // sign + SmallVector inputIndices; + for (unsigned int i = 0; i < resultRank; i++) { + if (i != dim1 && i != dim2) { + inputIndices.push_back(b.create(loc, i)); + } + } + + // adjust output diagonal indices and last input Index based + // on offset + Value dim1IdxAdjusted; + Value dim2IdxAdjusted; + if (offset < 0) { + Value absOffset = + b.create(loc, -offset); + dim1IdxAdjusted = dim1Index; + dim2IdxAdjusted = + b.create(loc, dim2Index, absOffset); + inputIndices.push_back( + b.create(loc, dim2)); + } else { + Value constOffset = + b.create(loc, offset); + dim1IdxAdjusted = + b.create(loc, dim1Index, constOffset); + dim2IdxAdjusted = dim2Index; + inputIndices.push_back( + b.create(loc, dim1)); + } + + Value isDiagonal = + b.create(loc, arith::CmpIPredicate::eq, + dim1IdxAdjusted, dim2IdxAdjusted); + + Value inputElem = b.create( + loc, resultElemType, input, inputIndices); + + Value result = rewriter.create( + loc, isDiagonal, inputElem, args[0]); + b.create(loc, result); + }) + .getResult(0); + + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + rewriter.replaceOpWithNewOp(op, resultType, resultTensor); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2136,4 +2290,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 6bd8e797a..38df3af98 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8253,6 +8253,91 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %2 = torch.aten.ne.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.lt.int %arg2, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n" +" %5 = torch.aten.ge.int %arg2, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.lt.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n" +" %8 = torch.aten.ge.int %arg3, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %15 = torch.aten.add.int %1, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %11 = torch.aten.lt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %15 = torch.aten.add.int %1, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.int\n" +" }\n" +" %13 = torch.prim.ListConstruct : () -> !torch.list\n" +" %14 = torch.prim.Loop %1, %true, init(%int0) {\n" +" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" +" %15 = torch.prim.ListConstruct %10, %12 : (!torch.int, !torch.int) -> !torch.list\n" +" %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int \n" +" %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %13, %20 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %arg5 : !torch.int\n" +" } else {\n" +" %18 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %13, %18 : !torch.list, !torch.int -> !torch.list\n" +" %20 = torch.aten.add.int %arg5, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" return %13 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12516,6 +12601,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.diag_embed\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index afadb52b1..55b7d8883 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1878,6 +1878,12 @@ ONNX_XFAIL_SET = { "DiagonalModule_with_dims_and_offset", "DiagonalModule_with_negative_dims", "DiagonalModule_with_offset", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceFloatProdModuleIncludeSelf", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 42dfeb533..10be100db 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -53,6 +53,32 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape +def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int): + self_rank = len(self) + result_rank = self_rank + 1 + + assert dim1 != dim2 + assert dim1 < result_rank + assert dim1 >= -(result_rank) + assert dim2 < result_rank + assert dim2 >= -(result_rank) + + if dim1 < 0: + dim1 = result_rank + dim1 + if dim2 < 0: + dim2 = result_rank + dim2 + + result_shape: List[int] = [] + input_dim_idx = 0 + for i in range(result_rank): + if i in (dim1, dim2): + result_shape.append(self[-1] + abs(offset)) + else: + result_shape.append(self[input_dim_idx]) + input_dim_idx += 1 + + return result_shape + def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) @@ -1057,6 +1083,20 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=3), # Test explicit dim1 and dim2. + Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=1, dim2=3), # Positive offset. + Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=3, dim2=1), # Reverse dim1 and dim2 + Invocation(TensorOfShape(2, 3, 4), offset=-1, dim1=1, dim2=3), # Negative offset + Invocation(TensorOfShape(2, 3, 4), offset=3), # large `offset`. + Invocation(TensorOfShape(2)), # Input one-dimensional. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=4, dim2=1), # `dim1` out of bounds. +]) +def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]: + return _diag_embed_shape_helper(self, offset, dim1, dim2) + def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -4200,6 +4240,11 @@ def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f5a442718..92746f8d2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -561,6 +561,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)") + emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)") # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index eb0143b9d..540fa2d22 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1873,3 +1873,118 @@ class EmptyStridedSizeIntStrideModule(torch.nn.Module): @register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule()) def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + + +class AtenDiagEmbedDefaultDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedDefaultDiag()) + def AtenDiagEmbedDefaultDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedDimDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=0, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedDimDiag()) + def AtenDiagEmbedDimDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedOffsetDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=1, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedOffsetDiag()) + def AtenDiagEmbedOffsetDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedRevDimDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=1, dim1=3, dim2=1) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedRevDimDiag()) + def AtenDiagEmbedRevDimDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedNegOffsetDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=-1, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedNegOffsetDiag()) + def AtenDiagEmbedNegOffsetDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +class AtenDiagEmbedNonDefault4DDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=-2, dim1=1, dim2=-3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag()) + def AtenDiagEmbedNonDefault4DDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 5)) \ No newline at end of file