diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 3034f239f..102fb2bbf 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6846,6 +6846,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"
@@ -10659,6 +10663,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int4 = torch.constant.int 4\n"
 "    return %int4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.argmin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    return %int4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    %int0 = torch.constant.int 0\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 745cdd251..b7b2c2670 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -840,12 +840,13 @@ public:
 };
 } // namespace
 
-// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
+// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp`.
 namespace {
-class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
+template <typename OpTy, typename DecompOpTy>
+class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<AtenArgmaxOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenArgmaxOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
@@ -870,7 +871,7 @@ public:
             .cast<BaseTensorType>();
 
     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
-    // `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
+    // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType`, so first the input
     // tensor is flattened to 1d tensor and then the reduction happens on the
     // 0th dimension.
    if (dim.getType().isa<Torch::NoneType>()) {
@@ -885,13 +886,14 @@ public:
       input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
                                                       dim, end);
     }
 
-    Value maxResult =
-        rewriter
-            .create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
-                                  input, dim, keepDim)
-            .getIndices();
-    rewriter.replaceOp(op, maxResult);
+    Value resultArg =
+        rewriter
+            .create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
+                                input, dim, keepDim)
+            .getIndices();
+
+    rewriter.replaceOp(op, resultArg);
     return success();
   }
 };
@@ -5774,7 +5776,8 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 348f521d5..71b3a9d91 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenArgminOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 110599322..478c5a161 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -445,6 +445,10 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] =
 def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
+def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
+    # There is no shape function for argmin in PyTorch, but the one for argmax does exactly what is needed here.
+    return upstream_shape_functions.argmax(self, dim, keepdim)
+
 # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
 # making it impossible to add support for it using the current design of the shape library.
 def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
@@ -3254,7 +3258,10 @@ def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Li
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
-    self_rank, self_dtype = self_rank_dtype
+    return torch.int64
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇argmin〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
     return torch.int64
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index 5ef44560d..ee6878701 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -40,7 +40,6 @@ def register_all_tests():
     from . import type_conversion
     from . import backprop
     from . import reduction
-    from . import argmax
    from . import matmul
     from . import reshape_like
     from . import scalar
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py
deleted file mode 100644
index 098ed508b..000000000
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import torch
-
-from torch_mlir_e2e_test.framework import TestUtils
-from torch_mlir_e2e_test.registry import register_test_case
-from torch_mlir_e2e_test.annotations import annotate_args, export
-
-# ==============================================================================
-
-class ArgmaxModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return torch.argmax(a)
-
-
-@register_test_case(module_factory=lambda: ArgmaxModule())
-def ArgmaxModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4))
-
-# ==============================================================================
-
-class ArgmaxWithDimModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, a):
-        return torch.argmax(a, dim=1)
-
-@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
-def ArgmaxModule_with_dim(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4, 5))
-
-# ==============================================================================
-
-class ArgmaxKeepDimsModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-    ])
-    def forward(self, a):
-        return torch.argmax(a, 0, True)
-
-@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
-def ArgmaxModule_keepDim(module, tu: TestUtils):
-    module.forward(tu.rand(4, 6))
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index dcef324fc..75e6eb261 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -755,6 +755,171 @@ def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, 5, high=100))
 
 # ==============================================================================
+
+class ArgminModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmin(a)
+
+
+@register_test_case(module_factory=lambda: ArgminModule())
+def ArgminModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ArgminIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmin(a)
+
+
+@register_test_case(module_factory=lambda: ArgminIntModule())
+def ArgminIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+@register_test_case(module_factory=lambda: ArgminIntModule())
+def ArgminIntModule_multiple_mins(module, tu: TestUtils):
+    # To cover the special case where the minimum value occurs more than once.
+    # The PyTorch convention is to treat the first occurrence as the argmin.
+    module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))
+
+# ==============================================================================
+
+class ArgminWithDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmin(a, dim=1)
+
+@register_test_case(module_factory=lambda: ArgminWithDimModule())
+def ArgminModule_with_dim(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ArgminKeepDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmin(a, 0, True)
+
+@register_test_case(module_factory=lambda: ArgminKeepDimsModule())
+def ArgminModule_keepDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 6))
+
+# ==============================================================================
+
+class ArgmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxModule())
+def ArgmaxModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ArgmaxIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxIntModule())
+def ArgmaxIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+@register_test_case(module_factory=lambda: ArgmaxIntModule())
+def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils):
+    # To cover the special case where the maximum value occurs more than once.
+    # The PyTorch convention is to treat the first occurrence as the argmax.
+    module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))
+
+# ==============================================================================
+
+class ArgmaxWithDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmax(a, dim=1)
+
+@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
+def ArgmaxModule_with_dim(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ArgmaxKeepDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmax(a, 0, True)
+
+@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
+def ArgmaxModule_keepDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 6))
+
+# ==============================================================================
+
 class ReduceL1NormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
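For reference, a minimal eager-mode sketch (not part of the patch) of the equivalence the `DecomposeAtenArgMinMaxOp` pattern above relies on: `aten.argmin`/`aten.argmax` reduce to the `indices` result of `aten.min.dim`/`aten.max.dim`, with the input flattened to 1-D first when `dim` is `None`. This uses only standard PyTorch APIs and assumes the first-occurrence tie-breaking behaviour exercised by the e2e tests above.

```python
import torch

x = torch.tensor([[4.0, 1.0], [1.0, 3.0]])

# dim=None path: flatten to 1-D, reduce along dim 0, keep only the indices.
assert torch.argmin(x) == torch.min(x.flatten(0, x.dim() - 1), dim=0).indices
assert torch.argmax(x) == torch.max(x.flatten(0, x.dim() - 1), dim=0).indices

# Explicit dim: indices of the reduction along that dimension
# (ties resolve to the first occurrence, matching the e2e tests above).
assert torch.equal(torch.argmin(x, dim=1), torch.min(x, dim=1).indices)
assert torch.equal(torch.argmax(x, dim=1), torch.max(x, dim=1).indices)
```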