Add test with pytorch fix coming in

pull/3733/head
jinchen62 2024-10-24 12:45:17 -07:00
parent 443b5cd2e1
commit 7cf1bb9a93
3 changed files with 108 additions and 0 deletions

View File

@@ -8183,6 +8183,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
" %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: Input be of rank 3 or 4\"\n"
" %true = torch.constant.bool true\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %15 : !torch.list<int>\n"
" } else {\n"
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %14 : !torch.list<int>\n"
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
@@ -12133,6 +12194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@@ -1059,6 +1059,15 @@ def atenmax_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
return maxpool3d, indices
def atenmax_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
    """Shape transfer function for max_unpool2d.

    Keeps the leading (batch/channel) dims of the input and replaces the two
    trailing spatial dims with ``output_size``.
    """
    rank = len(self)
    assert rank == 4 or rank == 3, "Input be of rank 3 or 4"
    assert len(output_size) == 2, "output_size must have 2 elements"
    assert rank == len(indices), "Input and indices must be of the same rank"
    # rank 4 -> [N, C, H_out, W_out]; rank 3 -> [C, H_out, W_out]
    return self[: rank - 2] + [output_size[0], output_size[1]]
def atenmax_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
assert (len(output_size) == 3), "output_size must have 3 elements"
@@ -3205,6 +3214,10 @@ def atenmax_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64
def atenmax_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
    """Dtype transfer function for max_unpool2d: result dtype is the input's dtype."""
    return self_rank_dtype[1]
def atenmax_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
    """Dtype transfer function for max_unpool3d: result dtype is the input's dtype."""
    return self_rank_dtype[1]

View File

@@ -1988,6 +1988,36 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
# ==============================================================================
class MaxUnpool2dModule(torch.nn.Module):
    """E2E test module that invokes aten.max_unpool2d directly."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, 2, 2], torch.float32, True),
            ([-1, -1, 2, 2], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        # NOTE(review): this passes output_size, stride, and padding to
        # aten.max_unpool2d; per the commit message this relies on a PyTorch
        # fix that is still landing — confirm the op signature once it does.
        return torch.ops.aten.max_unpool2d(x, indices, (4, 4), (2, 2), (0, 0))
@register_test_case(module_factory=lambda: MaxUnpool2dModule())
def MaxUnpool2dModule_basic(module, tu: TestUtils):
    """Pool a random input with return_indices=True, then unpool it."""
    inp = tu.rand(2, 2, 4, 4)
    # return_indices=True produces the argmax map that max_unpool2d consumes.
    pooled, idx = torch.nn.MaxPool2d(
        kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), return_indices=True
    )(inp)
    module.forward(pooled, idx)
# ==============================================================================
class MaxUnpool3dModule(torch.nn.Module):
def __init__(self):
super().__init__()