mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] improve argmax/argmin's decomposition to support keepdim=True when dim=None (#3514)
parent 2f231f394e
commit e2fbded49c
@@ -7313,11 +7313,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " return %2 : !torch.list<int>\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
-" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @__torch__.patched_argmax_shape_func(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+" %true = torch.constant.bool true\n"
+" %false = torch.constant.bool false\n"
+" %none = torch.constant.none\n"
+" %int1 = torch.constant.int 1\n"
+" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
+" torch.prim.If.yield %arg2 : !torch.bool\n"
+" } else {\n"
+" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+" %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" torch.prim.Loop %4, %true, init() {\n"
+" ^bb0(%arg3: !torch.int):\n"
+" %5 = torch.aten.append.t %3, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+" torch.prim.Loop.condition %true, iter()\n"
+" } : (!torch.int, !torch.bool) -> ()\n"
+" torch.prim.If.yield %3 : !torch.list<int>\n"
+" } else {\n"
+" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+" torch.prim.If.yield %3 : !torch.list<int>\n"
+" }\n"
+" return %2 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
-" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
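The serialized MLIR above is auto-generated from the Python shape function added later in this commit. As a reading aid, its control flow boils down to the following sketch (illustrative Python, not the generated code; the function name here is made up):

from typing import List, Optional

def argmax_shape_sketch(sizes: List[int], dim: Optional[int], keepdim: bool) -> List[int]:
    # First prim.If: the new path is taken only when dim is None and keepdim is true.
    if dim is None and keepdim:
        # prim.Loop over the input rank, appending a constant 1 per dimension.
        return [1 for _ in sizes]
    # Otherwise the else branch calls torch.jit._shape_functions.argmax as before.
    raise NotImplementedError("upstream argmax shape logic handles this case")

# A [2, 3, 4] input reduced with dim=None, keepdim=True keeps its rank but collapses to ones.
assert argmax_shape_sketch([2, 3, 4], dim=None, keepdim=True) == [1, 1, 1]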
@@ -7372,19 +7399,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " return %2 : !torch.list<int>\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
-" %none = torch.constant.none\n"
-" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
-" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
-" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
-" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
-" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
-" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
-" } else {\n"
-" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
-" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
-" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
-" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
-" }\n"
+" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 " return %1 : !torch.tuple<list<int>, list<int>>\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
@@ -1920,15 +1920,19 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
    Value dim = op.getDim();
-    Value keepDim = op.getKeepdim();
     Value result = op.getResult();

+    bool keepDim;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected keepdim to be a constant bool");
+    }
     BaseTensorType inputType = cast<BaseTensorType>(input.getType());
     BaseTensorType indicesTensorType = cast<BaseTensorType>(result.getType());
     std::optional<unsigned> maybeInputRank = getTensorRank(input);
-    if (!maybeInputRank) {
+    if (!maybeInputRank || *maybeInputRank == 0) {
       return rewriter.notifyMatchFailure(
-          op, "expected input tensor to have a rank");
+          op, "expected input tensor to have a rank > 0");
     }
     unsigned inputRank = *maybeInputRank;
     if (!indicesTensorType.hasSizes())
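The tightened guard above means the decomposition now also bails out on rank-0 inputs, which the flatten-and-reshape path below cannot handle. For context, eager PyTorch already returns a 0-d index for that case, so nothing is lost by not decomposing it (a hypothetical check, not part of the patch):

import torch

# argmax of a 0-d tensor is itself 0-d; there is no dimension to flatten or keep.
scalar = torch.tensor(5.0)
print(torch.argmax(scalar))        # tensor(0)
print(torch.argmax(scalar).shape)  # torch.Size([])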
@@ -1945,21 +1949,49 @@ public:
       BaseTensorType flattenType =
           cast<BaseTensorType>(inputType.getWithSizesAndDtype(
               {kUnknownSize}, inputType.getOptionalDtype()));
-      dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+      Value zero =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
       Value end = rewriter.create<ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
       input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
-                                                      dim, end);
+                                                      zero, end);
+      Value resultIndices =
+          rewriter
+              .create<DecompOpTy>(
+                  loc,
+                  valueTensorType.getWithSizesAndDtype(
+                      ArrayRef<int64_t>{}, valueTensorType.getOptionalDtype()),
+                  indicesTensorType.getWithSizesAndDtype(
+                      ArrayRef<int64_t>{},
+                      indicesTensorType.getOptionalDtype()),
+                  input, /*dim=*/zero, /*keepdim=*/falseValue)
+              .getIndices();
+      if (keepDim) {
+        Value one =
+            rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+        Value dimList = rewriter.create<PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
+            SmallVector<Value>(inputRank, one));
+        resultIndices = rewriter.create<AtenReshapeOp>(
+            loc,
+            indicesTensorType.getWithSizesAndDtype(
+                SmallVector<int64_t>(inputRank, 1),
+                indicesTensorType.getOptionalDtype()),
+            resultIndices, dimList);
+      }
+      rewriter.replaceOp(op, resultIndices);
+      return success();
+    } else {
+      Value resultIndices =
+          rewriter
+              .create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
+                                  input, dim, op.getKeepdim())
+              .getIndices();
+      rewriter.replaceOp(op, resultIndices);
+      return success();
     }
-
-    Value resultArg =
-        rewriter
-            .create<DecompOpTy>(loc, valueTensorType, indicesTensorType, input,
-                                dim, keepDim)
-            .getIndices();
-
-    rewriter.replaceOp(op, resultArg);
-    return success();
   }
 };
 } // namespace
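In eager PyTorch terms, the new dim=None branch above does the following: flatten the input, run argmax/argmin on the flattened tensor with dim=0 and keepdim=False, and, if the original op asked for keepdim=True, reshape the 0-d result to an all-ones shape. A rough sketch of that strategy (illustrative only; the real pattern emits the corresponding Torch dialect ops and handles argmin as well via DecompOpTy):

import torch

def decomposed_argmax(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    # Sketch of the dim=None decomposition: flatten, reduce, optionally reshape.
    flat = torch.flatten(x, 0, x.dim() - 1)           # AtenFlattenUsingIntsOp
    idx = torch.argmax(flat, dim=0, keepdim=False)    # DecompOpTy with dim=zero, keepdim=false
    if keepdim:
        idx = torch.reshape(idx, [1] * x.dim())       # AtenReshapeOp to an all-ones shape
    return idx

x = torch.rand(3, 4)
assert decomposed_argmax(x, keepdim=True).shape == (1, 1)
assert decomposed_argmax(x, keepdim=True).item() == torch.argmax(x).item()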
@@ -1505,6 +1505,7 @@ STABLEHLO_CRASHING_SET = {"IndexPutWithNoneAndBroadcastModule_basic"}
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ArgmaxKeepdimModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
@@ -680,8 +680,19 @@ def aten〇trace〡shape(self: List[int]) -> List[int]:
     assert len(self) == 2, "input must have rank 2"
     return []
+
+# TODO: replace this patched function with `upstream_shape_functions.argmax` when upstream fix it
+# see https://github.com/pytorch/pytorch/pull/129838
+def patched_argmax_shape_func(self: List[int], dim: Optional[int] = None, keepdim: bool = False):
+    if dim is None and keepdim:
+        out: List[int] = []
+        for i in self:
+            out.append(1)
+        return out
+    return upstream_shape_functions.argmax(self, dim, keepdim)
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3, 4)), # Basic case.
     Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`.
     Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
     Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
     Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
@@ -690,11 +701,11 @@ def aten〇trace〡shape(self: List[int]) -> List[int]:
     ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
 ])
 def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
-    return upstream_shape_functions.argmax(self, dim, keepdim)
+    return patched_argmax_shape_func(self, dim, keepdim)

 def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
     # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here.
-    return upstream_shape_functions.argmax(self, dim, keepdim)
+    return patched_argmax_shape_func(self, dim, keepdim)

 # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
 # making it impossible to add support for it using the current design of the shape library.
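A few concrete input/output pairs for the patched shape function, matching the invocations in the @check_shape_function decorator above (the helper below is a simplified standalone re-implementation without bounds checking, not the library code):

from typing import List, Optional

def shape_of_argmax(sizes: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
    # Simplified mirror of patched_argmax_shape_func, for illustration only.
    if dim is None:
        return [1] * len(sizes) if keepdim else []
    d = dim + len(sizes) if dim < 0 else dim
    return sizes[:d] + ([1] if keepdim else []) + sizes[d + 1:]

assert shape_of_argmax([2, 3, 4]) == []                       # full reduction, 0-d result
assert shape_of_argmax([2, 3, 4], keepdim=True) == [1, 1, 1]  # the newly supported case
assert shape_of_argmax([2, 3, 4], dim=0) == [3, 4]
assert shape_of_argmax([2, 3, 4], dim=0, keepdim=True) == [1, 3, 4]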
@@ -722,12 +733,19 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa
 def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)

+@check_shape_function([
+    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
+    Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`.
+    Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
+    Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
+    Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
+    Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
+])
 def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
-    if dim is None:
-        return [], []
-    else:
-        reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
-        return reduced_shape, reduced_shape
+    reduced_shape = patched_argmax_shape_func(self, dim, keepdim)
+    return reduced_shape, reduced_shape

 def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
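The practical change for aten.aminmax: previously the dim=None case always reported 0-d result shapes regardless of keepdim; now keepdim=True yields an all-ones shape for both outputs. A minimal sketch of just that path (hypothetical helper name, not the library code):

from typing import List

def aminmax_dim_none_shape(sizes: List[int], keepdim: bool) -> List[int]:
    # dim=None path of the new shape function: all-ones if keepdim, else a 0-d shape.
    return [1] * len(sizes) if keepdim else []

assert aminmax_dim_none_shape([2, 3, 4], keepdim=False) == []
assert aminmax_dim_none_shape([2, 3, 4], keepdim=True) == [1, 1, 1]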
@@ -1533,6 +1533,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ArgmaxKeepdimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.argmax(a, keepdim=True)
+
+
+@register_test_case(module_factory=lambda: ArgmaxKeepdimModule())
+def ArgmaxKeepdimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
 class ArgmaxIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
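For reference, what the new e2e test exercises, reproduced standalone in eager PyTorch (the harness itself compares the compiled module's output against this eager result):

import torch

# Same computation as ArgmaxKeepdimModule.forward, run directly in eager mode.
a = torch.rand(3, 4)
result = torch.ops.aten.argmax(a, keepdim=True)

# With dim omitted and keepdim=True, every dimension is retained with size 1.
assert result.shape == (1, 1)
assert result.item() == torch.argmax(a).item()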