mirror of https://github.com/llvm/torch-mlir
Add ReflectionPad2d dynamic test
parent
5c0cf81905
commit
a2973b0f2f
|
@ -36,6 +36,29 @@ def ReflectionPad2dModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ReflectionPad2dDynamicSizesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReflectionPad2dDynamicSizesModule())
|
||||||
|
def ReflectionPad2dDynamicSizesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 20, 20, low=-1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ReflectionPad2dModuleTop(torch.nn.Module):
|
class ReflectionPad2dModuleTop(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue