Mirror of https://github.com/llvm/torch-mlir

Add a decomposition for torch.aten.argmin (#2613)

Adds a lowering for the torch.aten.argmin operator to linalg via decomposition into torch.aten.min.dim.

Co-authored-by: Franz Haniel <franz.haniel@amd.com>

parent 6244f301fb
commit c0115706a0
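For orientation, the sketch below (illustrative only, written against plain PyTorch rather than the Torch dialect ops, and not part of the commit) shows the equivalence the decomposition relies on: the indices result of min.dim/max.dim along a dimension is exactly argmin/argmax along that dimension.

import torch

# Illustrative equivalence behind the decomposition: argmin/argmax are the
# `indices` half of the min.dim/max.dim reductions.
x = torch.randn(3, 4, 5)
assert torch.equal(torch.argmin(x, dim=1), torch.min(x, dim=1).indices)
assert torch.equal(torch.argmax(x, dim=1), torch.max(x, dim=1).indices)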
@@ -6846,6 +6846,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"

@@ -10659,6 +10663,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int4 = torch.constant.int 4\n"
 "    return %int4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.argmin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    return %int4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    %int0 = torch.constant.int 0\n"

@@ -840,12 +840,13 @@ public:
 };
 } // namespace

-// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
+// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp`
 namespace {
-class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
+template <typename OpTy, typename DecompOpTy>
+class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenArgmaxOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
@@ -870,7 +871,7 @@ public:
             .cast<BaseTensorType>();

     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
-    // `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
+    // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input
     // tensor is flattened to 1d tensor and then the reduction happens on the
     // 0th dimension.
     if (dim.getType().isa<Torch::NoneType>()) {
@@ -885,13 +886,14 @@ public:
       input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
                                                       dim, end);
     }
-    Value maxResult =
-        rewriter
-            .create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
-                                  input, dim, keepDim)
-            .getIndices();
-
-    rewriter.replaceOp(op, maxResult);
+    Value resultArg =
+        rewriter
+            .create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
+                                input, dim, keepDim)
+            .getIndices();
+
+    rewriter.replaceOp(op, resultArg);
     return success();
   }
 };

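To make the dim=None path concrete, here is a small stand-alone sketch of the strategy the pattern's comment describes (a hypothetical helper in plain PyTorch using public torch APIs, not the MLIR rewrite itself): when no dimension is given, flatten to 1-D and reduce along dimension 0.

import torch

def argmin_via_min_dim(x: torch.Tensor, dim=None, keepdim: bool = False) -> torch.Tensor:
    # Sketch of the decomposition strategy: min.dim does not accept dim=None,
    # so flatten to 1-D first and reduce along dimension 0.
    if dim is None:
        return torch.min(x.flatten(), dim=0, keepdim=keepdim).indices
    return torch.min(x, dim=dim, keepdim=keepdim).indices

x = torch.randn(3, 4)
assert torch.equal(argmin_via_min_dim(x), torch.argmin(x))
assert torch.equal(argmin_via_min_dim(x, dim=1, keepdim=True), torch.argmin(x, dim=1, keepdim=True))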
@@ -5774,7 +5776,8 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);

@@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenArangeOp>();
   target.addIllegalOp<AtenArangeStartOp>();
   target.addIllegalOp<AtenArgmaxOp>();
+  target.addIllegalOp<AtenArgminOp>();
   target.addIllegalOp<AtenSquareOp>();
   target.addIllegalOp<AtenVarOp>();
   target.addIllegalOp<AtenStdOp>();

@@ -445,6 +445,10 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] =
 def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)

+def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
+    # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here.
+    return upstream_shape_functions.argmax(self, dim, keepdim)
+
 # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
 # making it impossible to add support for it using the current design of the shape library.
 def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:

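For intuition, the shape rule that argmax and argmin share can be sketched as follows. This is an illustrative re-implementation, not the actual upstream_shape_functions.argmax code: dim=None yields a rank-0 shape, otherwise the reduced dimension is dropped, or kept with size 1 when keepdim is set.

from typing import List, Optional

def arg_reduce_shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
    # Hypothetical helper sketching the shape rule shared by argmax/argmin.
    if dim is None:
        return []                 # reducing over all dims yields a rank-0 result
    dim = dim % len(self)         # normalize negative dims
    out = list(self)
    if keepdim:
        out[dim] = 1              # keep the reduced dim with size 1
    else:
        out.pop(dim)              # drop the reduced dim
    return out

assert arg_reduce_shape([3, 4, 5]) == []
assert arg_reduce_shape([3, 4, 5], dim=1) == [3, 5]
assert arg_reduce_shape([3, 4, 5], dim=-1, keepdim=True) == [3, 4, 1]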
@@ -3254,7 +3258,10 @@ def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Li

 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
     return torch.int64

+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇argmin〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
+    return torch.int64
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))

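The dtype rule is deliberately simple: argmin and argmax always produce int64 index tensors regardless of the input dtype (in the generated abstract interp library this appears as the constant %int4, which, to my understanding, is torch-mlir's encoding of torch.int64). A quick stand-alone check, not part of the commit:

import torch

# argmin/argmax index results are always int64, independent of the input dtype.
for dtype in (torch.float32, torch.int32, torch.float64):
    x = torch.arange(12, dtype=dtype).reshape(3, 4)
    assert torch.argmin(x).dtype == torch.int64
    assert torch.argmax(x, dim=1).dtype == torch.int64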
@@ -40,7 +40,6 @@ def register_all_tests():
     from . import type_conversion
     from . import backprop
     from . import reduction
-    from . import argmax
     from . import matmul
     from . import reshape_like
     from . import scalar

@@ -1,65 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import torch
-
-from torch_mlir_e2e_test.framework import TestUtils
-from torch_mlir_e2e_test.registry import register_test_case
-from torch_mlir_e2e_test.annotations import annotate_args, export
-
-# ==============================================================================
-
-class ArgmaxModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return torch.argmax(a)
-
-
-@register_test_case(module_factory=lambda: ArgmaxModule())
-def ArgmaxModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4))
-
-# ==============================================================================
-
-class ArgmaxWithDimModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, a):
-        return torch.argmax(a, dim=1)
-
-@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
-def ArgmaxModule_with_dim(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4, 5))
-
-# ==============================================================================
-
-class ArgmaxKeepDimsModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-    ])
-    def forward(self, a):
-        return torch.argmax(a, 0, True)
-
-@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
-def ArgmaxModule_keepDim(module, tu: TestUtils):
-    module.forward(tu.rand(4, 6))

@@ -755,6 +755,171 @@ def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, 5, high=100))

+# ==============================================================================
+
+class ArgminModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmin(a)
+
+
+@register_test_case(module_factory=lambda: ArgminModule())
+def ArgminModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ArgminIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmin(a)
+
+
+@register_test_case(module_factory=lambda: ArgminIntModule())
+def ArgminIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+@register_test_case(module_factory=lambda: ArgminIntModule())
+def ArgminIntModule_multiple_mins(module, tu: TestUtils):
+    # To cover the special case that the minimal value occurs more than once.
+    # The PyTorch convention here is to consider the first occurrence as the argmin.
+    module.forward(torch.full((3, 4), tu.randint(1).item(), dtype=torch.int64))
+
+# ==============================================================================
+
+class ArgminWithDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmin(a, dim=1)
+
+@register_test_case(module_factory=lambda: ArgminWithDimModule())
+def ArgminModule_with_dim(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ArgminKeepDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmin(a, 0, True)
+
+@register_test_case(module_factory=lambda: ArgminKeepDimsModule())
+def ArgminModule_keepDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 6))
+
+# ==============================================================================
+
+class ArgmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxModule())
+def ArgmaxModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ArgmaxIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, a):
+        return torch.ops.aten.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxIntModule())
+def ArgmaxIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100))
+
+@register_test_case(module_factory=lambda: ArgmaxIntModule())
+def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils):
+    # To cover the special case that the maximal value occurs more than once.
+    # The PyTorch convention here is to consider the first occurrence as the argmax.
+    module.forward(torch.full((3, 4), tu.randint(1).item(), dtype=torch.int64))
+
+# ==============================================================================
+
+class ArgmaxWithDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmax(a, dim=1)
+
+@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
+def ArgmaxModule_with_dim(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ArgmaxKeepDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.argmax(a, 0, True)
+
+@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
+def ArgmaxModule_keepDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 6))
+
 # ==============================================================================

 class ReduceL1NormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
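The multiple_mins/multiple_maxs tests above exercise tie-breaking; as a stand-alone illustration (plain PyTorch, not part of the commit) of the first-occurrence convention they rely on:

import torch

# Every element is equal, so every position is both a minimum and a maximum;
# PyTorch documents that the index of the first occurrence is returned.
x = torch.full((3, 4), 7, dtype=torch.int64)

assert torch.argmin(x).item() == 0
assert torch.argmax(x, dim=1).tolist() == [0, 0, 0]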