mirror of https://github.com/llvm/torch-mlir
Add test with pytorch fix coming in
parent
443b5cd2e1
commit
7cf1bb9a93
|
@ -8183,6 +8183,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %1 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: Input be of rank 3 or 4\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %15 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %14 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %10 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
|
||||
|
@ -12133,6 +12194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -1059,6 +1059,15 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
|
|||
maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
return maxpool3d, indices
|
||||
|
||||
def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
|
||||
assert (len(self) == 4 or len(self) == 3), "Input be of rank 3 or 4"
|
||||
assert (len(output_size) == 2), "output_size must have 2 elements"
|
||||
assert (len(self) == len(indices)), "Input and indices must be of the same rank"
|
||||
if len(self) == 4:
|
||||
return [self[0], self[1], output_size[0], output_size[1]]
|
||||
else:
|
||||
return [self[0], output_size[0], output_size[1]]
|
||||
|
||||
def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
|
||||
assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
|
||||
assert (len(output_size) == 3), "output_size must have 3 elements"
|
||||
|
@ -3205,6 +3214,10 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype, torch.int64
|
||||
|
||||
def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
|
|
@ -1988,6 +1988,36 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxUnpool2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, 2, 2], torch.float32, True),
|
||||
([-1, -1, 2, 2], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, indices):
|
||||
return torch.ops.aten.max_unpool2d(x, indices, (4, 4), (2, 2), (0, 0))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxUnpool2dModule())
|
||||
def MaxUnpool2dModule_basic(module, tu: TestUtils):
|
||||
input = tu.rand(2, 2, 4, 4)
|
||||
pool = torch.nn.MaxPool2d(
|
||||
kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), return_indices=True
|
||||
)
|
||||
output, indices = pool(input)
|
||||
|
||||
module.forward(output, indices)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxUnpool3dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue