Add a decomposition for torch.aten.argmin (#2613)

Adds a lowering for the torch.aten.argmin operator to linalg via decomposition into torch.aten.min.dim.

---------

Co-authored-by: Franz Haniel <franz.haniel@amd.com>
pull/2616/head
frafranz 2023-12-06 15:45:30 +01:00 committed by GitHub
parent 6244f301fb
commit c0115706a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 79 deletions

View File

@ -6846,6 +6846,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"
@ -10659,6 +10663,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.argmin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %int0 = torch.constant.int 0\n"

View File

@ -840,12 +840,13 @@ public:
};
} // namespace
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp`
namespace {
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
template <typename OpTy, typename DecompOpTy>
class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArgmaxOp op,
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
@ -870,7 +871,7 @@ public:
.cast<BaseTensorType>();
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input
// tensor is flattened to 1d tensor and then the reduction happens on the
// 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
@ -885,13 +886,14 @@ public:
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
}
Value maxResult =
rewriter
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.getIndices();
rewriter.replaceOp(op, maxResult);
Value resultArg =
rewriter
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.getIndices();
rewriter.replaceOp(op, resultArg);
return success();
}
};
@ -5774,7 +5776,8 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);

View File

@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();

View File

@ -445,6 +445,10 @@ def atenstdcorrection〡shape(self: List[int], dim: Optional[List[int]] =
def atenargmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.argmax(self, dim, keepdim)
def atenargmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
# There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here.
return upstream_shape_functions.argmax(self, dim, keepdim)
# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
# making it impossible to add support for it using the current design of the shape library.
def atenone_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
@ -3254,7 +3258,10 @@ def atenmeandim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Li
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenargmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return torch.int64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenargmin〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
return torch.int64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))

View File

@ -40,7 +40,6 @@ def register_all_tests():
from . import type_conversion
from . import backprop
from . import reduction
from . import argmax
from . import matmul
from . import reshape_like
from . import scalar

View File

@ -1,65 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a)
@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, dim=1)
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, 0, True)
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

View File

@ -755,6 +755,171 @@ def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100))
# ==============================================================================
class ArgminModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a)
@register_test_case(module_factory=lambda: ArgminModule())
def ArgminModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ArgminIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a)
@register_test_case(module_factory=lambda: ArgminIntModule())
def ArgminIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))
@register_test_case(module_factory=lambda: ArgminIntModule())
def ArgminIntModule_multiple_mins(module, tu: TestUtils):
# To cover the special case that the minimal value occurs more than once.
# The pytorch convention is here to consider the first occurence as the argmin.
module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))
# ==============================================================================
class ArgminWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a, dim=1)
@register_test_case(module_factory=lambda: ArgminWithDimModule())
def ArgminModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ArgminKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a, 0, True)
@register_test_case(module_factory=lambda: ArgminKeepDimsModule())
def ArgminModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))
# ==============================================================================
class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a)
@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ArgmaxIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a)
@register_test_case(module_factory=lambda: ArgmaxIntModule())
def ArgmaxIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))
@register_test_case(module_factory=lambda: ArgmaxIntModule())
def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils):
# To cover the special case that the maximal value occurs more than once.
# The pytorch convention is here to consider the first occurence as the argmax.
module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))
# ==============================================================================
class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a, dim=1)
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a, 0, True)
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))
# ==============================================================================
class ReduceL1NormModule(torch.nn.Module):
def __init__(self):
super().__init__()