From c7d972ed580324c61a73607530f52d7918dd86c2 Mon Sep 17 00:00:00 2001
From: Branko Trifkovic <88882867+BaneTrifa@users.noreply.github.com>
Date: Thu, 18 Jul 2024 15:08:12 +0200
Subject: [PATCH] Implement lowering of torch.aten.tril_indices (#3517)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  30 +++
 lib/Dialect/Torch/IR/TorchOps.cpp             |  37 ++++
 .../Transforms/AbstractInterpLibrary.cpp      |  59 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 209 ++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 projects/pt1/e2e_testing/xfail_sets.py        |   4 +
 .../build_tools/abstract_interp_lib_gen.py    |  26 +++
 .../build_tools/torch_ods_gen.py              |   5 +
 .../test_suite/elementwise.py                 |  79 +++++++
 9 files changed, 450 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 626e259fe..964b045a9 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15857,6 +15857,36 @@ def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [
   let hasVerifier = 1;
 }
 
+def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_IntType:$row,
+    Torch_IntType:$col,
+    Torch_IntType:$offset,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTrilIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenTrilIndicesOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 53372006d..422883914 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5252,3 +5252,40 @@ LogicalResult AtenTriuIndicesOp::verify() {
 
   return success();
 }
+
+// AtenTrilIndicesOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenTrilIndicesOp::verify() {
+
+  // Check if row, col and offset are constant ints
+  int64_t row;
+  if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
+    return success();
+
+  int64_t col;
+  if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
+    return success();
+
+  int64_t offset;
+  if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
+    return success();
+
+  // Check if values of row and col are valid
+  if (row < 0)
+    return emitOpError("row must be non-negative, got ") << row;
+
+  if (col < 0)
+    return emitOpError("col must be non-negative, got ") << col;
+
+  // Check if dtype is valid
+  int64_t dtype;
+  if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
+    return success();
+  if (dtype != (int)torch_upstream::ScalarType::Int &&
+      dtype != (int)torch_upstream::ScalarType::Long)
+    return emitOpError(
+        "'tril_indices' implemented only for torch.int32 and torch.int64");
+
+  return success();
+}
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 218c6840d..90afe5ee3 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9993,6 +9993,53 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    }\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        %21 = torch.aten.add.int %int1, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"        %22 = torch.prim.min.int %arg1, %21 : !torch.int, !torch.int -> !torch.int\n"
+"        torch.prim.If.yield %22 : !torch.int\n"
+"      } else {\n"
+"        %21 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"        %22 = torch.aten.gt.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %23 = torch.aten.Int.bool %22 : !torch.bool -> !torch.int\n"
+"        torch.prim.If.yield %23 : !torch.int\n"
+"      }\n"
+"      %5 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"      %6 = torch.prim.min.int %arg1, %5 : !torch.int, !torch.int -> !torch.int\n"
+"      %7 = torch.prim.max.int %int0, %6 : !torch.int, !torch.int -> !torch.int\n"
+"      %8 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"      %9 = torch.prim.min.int %arg0, %8 : !torch.int, !torch.int -> !torch.int\n"
+"      %10 = torch.prim.max.int %int0, %9 : !torch.int, !torch.int -> !torch.int\n"
+"      %11 = torch.aten.sub.int %7, %4 : !torch.int, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.add.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.mul.int %13, %12 : !torch.int, !torch.int -> !torch.int\n"
+"      %15 = torch.aten.floordiv.int %14, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.sub.int %10, %12 : !torch.int, !torch.int -> !torch.int\n"
+"      %17 = torch.aten.mul.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"      %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n"
+"      %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int\n"
+"      %20 = torch.prim.ListConstruct %int2, %19 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %20 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14773,6 +14820,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int4 : !torch.int\n"
+"    } else {\n"
+"      %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %2 : !torch.int\n"
+"    }\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int3 = torch.constant.int 3\n"
 "    %int1 = torch.constant.int 1\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2af330280..df044f52f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -992,6 +992,214 @@ public:
 };
 } // namespace
 
+// decomposition of torch.tril_indices
+// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5797
+namespace {
+class DecomposeAtenTrilIndicesOp : public OpRewritePattern<AtenTrilIndicesOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTrilIndicesOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    // Required parameters
+    Value row = op.getRow();
+    Value col = op.getCol();
+    Value offset = op.getOffset();
+
+    // Check if row, col and offset are constant ints
+    int64_t rowInt;
+    if (!matchPattern(row, m_TorchConstantInt(&rowInt)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented: row not constant int");
+
+    int64_t colInt;
+    if (!matchPattern(col, m_TorchConstantInt(&colInt)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented: col not constant int");
+
+    int64_t offsetInt;
+    if (!matchPattern(offset, m_TorchConstantInt(&offsetInt)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: offset not constant int");
+
+    // Optional parameters
+    Value dtype = op.getDtype();
+    Value layout = op.getLayout();
+    Value device = op.getDevice();
+    Value pinMemory = op.getPinMemory();
+
+    // Constants
+    Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+    Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
+    Value cstTwo = rewriter.create<Torch::ConstantIntOp>(loc, 2);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value cstZeroPointFive = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(0.5));
+    Value cstTwoFloat = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(2.0));
+
+    // Get int value for dtype
+    int64_t dtypeInt;
+    if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dtype not constant int");
+
+    FailureOr<Type> dtypeType =
+        getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
+    if (failed(dtypeType))
+      return rewriter.notifyMatchFailure(op, "dtype is undefined");
+
+    // Calculate trapezoidSize, rectangleSize and mFirstRow
+    std::tuple<int64_t, int64_t, int64_t> trilSizes =
+        getTrilSizes(rowInt, colInt, offsetInt);
+
+    int64_t trapezoidSizeInt = std::get<0>(trilSizes);
+    int64_t rectangleSizeInt = std::get<1>(trilSizes);
+    int64_t mFirstRowInt = std::get<2>(trilSizes);
+
+    // Create const int Values from ints
+    Value trapezoidSize = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(trapezoidSizeInt));
+    Value rectangleSize = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rectangleSizeInt));
+    Value mFirstRow = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(mFirstRowInt));
+
+    // Calculate row offset
+    int64_t rowOffsetInt = (-offsetInt > 0) ? (-offsetInt) : 0;
+    Value rowOffset = rewriter.create<Torch::ConstantIntOp>(loc, rowOffsetInt);
+
+    // First we do the indices for TOP trapezoid
+    auto f64DtypeInt =
+        getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type());
+    auto arrangeType =
+        getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type());
+    Value xs1 =
+        rewriter.create<AtenArangeOp>(loc, arrangeType, trapezoidSize,
+                                      /*dtype=*/f64DtypeInt, /*layout=*/layout,
+                                      /*device=*/device,
+                                      /*pin_memory=*/pinMemory);
+
+    // b = m_first_row - 0.5
+    Value mFirstRowFloat = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(mFirstRowInt));
+    Value b =
+        rewriter.create<AtenSubFloatOp>(loc, mFirstRowFloat, cstZeroPointFive);
+
+    // Implements this piece of code: row_inds1 = torch.floor(-b + torch.sqrt(b
+    // * b + 2 * xs1))
+    Value bSquare = rewriter.create<AtenMulFloatOp>(loc, b, b);
+
+    Value twoTimesXs1 =
+        rewriter.create<AtenMulScalarOp>(loc, xs1.getType(), xs1, cstTwoFloat);
+    Value sqrtInput = rewriter.create<AtenAddScalarOp>(
+        loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne);
+
+    Value sqrt =
+        rewriter.create<AtenSqrtOp>(loc, sqrtInput.getType(), sqrtInput);
+
+    Value rowInds1 =
+        rewriter.create<AtenSubScalarOp>(loc, sqrt.getType(), sqrt, b, cstOne);
+    rowInds1 = rewriter.create<AtenFloorOp>(loc, rowInds1.getType(), rowInds1);
+
+    // Implements this piece of code: col_inds1 = torch.floor(xs1 - (2 *
+    // m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
+    Value twoTimesMFirstRow =
+        rewriter.create<AtenMulIntOp>(loc, cstTwo, mFirstRow);
+    twoTimesMFirstRow =
+        rewriter.create<AtenSubIntOp>(loc, twoTimesMFirstRow, cstOne);
+    twoTimesMFirstRow = rewriter.create<AtenAddScalarOp>(
+        loc, rowInds1.getType(), rowInds1, twoTimesMFirstRow, cstOne);
+    twoTimesMFirstRow = rewriter.create<AtenMulTensorOp>(
+        loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, rowInds1);
+    twoTimesMFirstRow = rewriter.create<AtenMulScalarOp>(
+        loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, cstZeroPointFive);
+
+    Value colInds1 = rewriter.create<AtenSubTensorOp>(
+        loc, xs1.getType(), xs1, twoTimesMFirstRow, cstOne);
+    colInds1 = rewriter.create<AtenFloorOp>(loc, colInds1.getType(), colInds1);
+
+    // Convert top trapezoid indices to dtype
+    Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true);
+
+    auto rowInds1Type = cast<BaseTensorType>(rowInds1.getType());
+    ArrayRef<int64_t> sizes = rowInds1Type.getSizes();
+    Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type);
+    rowInds1 = rewriter.create<AtenAddScalarOp>(loc, rowInds1.getType(),
+                                                rowInds1, rowOffset, cstOne);
+    rowInds1 = rewriter.create<AtenToDtypeOp>(
+        loc, finalRowType, rowInds1, dtype,
+        /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+        /*memory_format=*/cstOne);
+
+    auto colInds1Type = cast<BaseTensorType>(colInds1.getType());
+    sizes = colInds1Type.getSizes();
+    Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type);
+    colInds1 = rewriter.create<AtenToDtypeOp>(
+        loc, finalColType, colInds1, dtype,
+        /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+        /*memory_format=*/cstOne);
+
+    // Calculate indices for BOTTOM rectangle
+    arrangeType = getTensorTypeFromShapeValues({rectangleSize}, *dtypeType);
+    Value xs2 =
+        rewriter.create<AtenArangeOp>(loc, arrangeType, rectangleSize,
+                                      /*dtype=*/dtype, /*layout=*/layout,
+                                      /*device=*/device,
+                                      /*pin_memory=*/pinMemory);
+
+    // Implements this line of code: row_inds2 = xs2 // col + (col - m_first_row
+    // + 1 + row_offset)
+    Value rowInds2 =
+        rewriter.create<AtenFloorDivideScalarOp>(loc, xs2.getType(), xs2, col);
+    int64_t addInt = colInt - mFirstRowInt + 1 + rowOffsetInt;
+    Value cstAdd = rewriter.create<Torch::ConstantIntOp>(loc, addInt);
+    rowInds2 = rewriter.create<AtenAddScalarOp>(loc, rowInds2.getType(),
+                                                rowInds2, cstAdd, cstOne);
+
+    // Implements this line of code: col_inds2 = xs2 % col
+    Value colInds2 =
+        rewriter.create<AtenRemainderScalarOp>(loc, xs2.getType(), xs2, col);
+
+    // Prepare tensors for concatenation
+    Type listElemType =
+        cast<BaseTensorType>(rowInds1.getType())
+            .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                  /*optionalDtype=*/nullptr);
+    Type listType = Torch::ListType::get(listElemType);
+
+    Value sequenceRow = rewriter.create<PrimListConstructOp>(
+        loc, listType, SmallVector<Value>{rowInds1, rowInds2});
+    Value sequenceCol = rewriter.create<PrimListConstructOp>(
+        loc, listType, SmallVector<Value>{colInds1, colInds2});
+
+    // Concatenate row and col indices
+    Type finalCatType = colInds1Type.getWithSizesAndDtype(
+        {rectangleSizeInt + trapezoidSizeInt}, int64Type);
+
+    Value catRow = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceRow,
+                                              /*dim=*/cstZero);
+    Value catCol = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceCol,
+                                              /*dim=*/cstZero);
+
+    // Make return value - stack row and col indices
+    Value sequence = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(context, rowInds1.getType()),
+        ValueRange{catRow, catCol});
+    Type finalStackType = colInds1Type.getWithSizesAndDtype(
+        ArrayRef<int64_t>{2, rectangleSizeInt + trapezoidSizeInt}, int64Type);
+
+    rewriter.replaceOpWithNewOp<AtenStackOp>(op, finalStackType, sequence,
+                                             cstZero);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -9063,6 +9271,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenTrilIndicesOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index bbce3926e..3adb96d1f 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -548,6 +548,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp<AtenTriuIndicesOp>();
+  target.addIllegalOp<AtenTrilIndicesOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 90b01c804..1b173d3ec 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1337,6 +1337,10 @@ STABLEHLO_PASS_SET = {
     "TriuIndicesModule_basic",
     "TriuIndicesAllZerosModule_basic",
     "TriuIndicesNegativeOffsetModule_basic",
+    "TrilIndicesAllZerosModule_basic",
+    "TrilIndicesModule_basic",
+    "TrilIndicesNegativeOffsetModule_basic",
+    "TrilIndicesOffsetGreaterThanRowModule_basic",
     "TupleModule_basic",
     "TypeAsDifferentModule_basic",
     "TypeAsSameModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index e7b6a0efe..316851313 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1882,6 +1882,29 @@ def aten〇triu_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
     return [2, triu_size]
 
+
+@check_shape_function([
+    Invocation(4, 3, 1),  # Basic case.
+    Invocation(0, 0, 0),  # All zeros case.
+    Invocation(7, 5, -2),  # Negative offset case.
+    Invocation(35, 55, 16),  # Larger values case.
+])
+def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
+    if row == 0 or col == 0:
+        return [2, 0]
+
+    m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
+    m_last_row = max(0, min(col, row + offset))
+    n_row_all = max(0, min(row, row + offset))
+    n_row_trapezoid = m_last_row - m_first_row + 1
+
+    # Number of elements in top trapezoid
+    trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
+    # Number of elements in bottom rectangle
+    diff_row = n_row_all - n_row_trapezoid
+    rectangle_size = max(0, diff_row * col)
+
+    return [2, trapezoid_size + rectangle_size]
 
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100),  # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100),  # No batch dim.
@@ -5254,6 +5277,9 @@ def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> i
 def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
     return torch.int64 if dtype is None else dtype
 
+def aten〇tril_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
+    return torch.int64 if dtype is None else dtype
+
 def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     if (self_dtype == torch.quint8):
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 07ab6dcc1..62ef59d50 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1090,6 +1090,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         has_verifier=True,
     )
 
+    emit(
+        "aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)",
+        has_verifier=True,
+    )
+
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 7002cee43..82c77fee9 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -6364,3 +6364,82 @@ class TriuIndicesNegativeOffsetModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: TriuIndicesNegativeOffsetModule())
 def TriuIndicesNegativeOffsetModule_basic(module, tu: TestUtils):
     module.forward()
+
+
+# ==============================================================================
+
+
+class TrilIndicesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return torch.ops.aten.tril_indices(4, 3, 1)
+
+
+@register_test_case(module_factory=lambda: TrilIndicesModule())
+def TrilIndicesModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class TrilIndicesAllZerosModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return torch.ops.aten.tril_indices(0, 0, 0)
+
+
+@register_test_case(module_factory=lambda: TrilIndicesAllZerosModule())
+def TrilIndicesAllZerosModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class TrilIndicesNegativeOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return torch.ops.aten.tril_indices(5, 16, -2)
+
+
+@register_test_case(module_factory=lambda: TrilIndicesNegativeOffsetModule())
+def TrilIndicesNegativeOffsetModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class TrilIndicesOffsetGreaterThanRowModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return torch.ops.aten.tril_indices(7, 9, 8)
+
+
+@register_test_case(module_factory=lambda: TrilIndicesOffsetGreaterThanRowModule())
+def TrilIndicesOffsetGreaterThanRowModule_basic(module, tu: TestUtils):
+    module.forward()