[FX] Add broadcast test with dynamic dim (#3123)

This scenario was uncovered in a downstream test that failed with a
previous snapshot of torch-mlir. See
https://github.com/cruise-automation/mlir-tcp/actions/runs/8605480116/job/23581829102?pr=65.
```
  File "/home/runner/.cache/bazel/_bazel_runner/ce288f117ee4ca92dc028a6a28476a3d/sandbox/processwrapper-sandbox/2380/execroot/mlir-tcp/bazel-out/k8-opt-exec-2B5CBBC6/bin/test/AotCompile/broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic_torch_exporter.runfiles/pip_deps_torch_mlir/site-packages/torch_mlir/extras/fx_importer.py", line 969, in value_info_to_type
    raise NotImplementedError(
NotImplementedError: Could not deduce type from value info: tensor_meta=None, val=s1, sparsity=None
```
The issue appears to be resolved at current HEAD. Adding this test to ensure
coverage going forward.
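
For context, a minimal sketch of the export side of the scenario (illustrative only, not part of this commit; assumes a recent PyTorch with `torch.export` and its `dynamic_shapes` API). Marking an input dim dynamic makes `torch.export` record a symbolic size, which is the `val=s1` the older `fx_importer` snapshot could not turn into a type:

```python
# Illustrative repro sketch, independent of torch-mlir. The dynamic dim on
# `y` becomes a symbolic size (e.g. `s0`/`s1`) in the exported program.
import torch
from torch.export import Dim, export


class Basic(torch.nn.Module):
    def forward(self, x, y):
        # Broadcast x's unit leading dim up to y's (dynamic) length.
        return torch.broadcast_to(x, (y.shape[0], -1))


x = torch.randn(1, 2)
y = torch.randn(10)

prog = export(Basic(), (x, y), dynamic_shapes={"x": {}, "y": {0: Dim("dim_0")}})
print(prog)  # the graph carries a SymInt-backed size for y's dim 0
```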
Author: Sambhav Jain (committed by GitHub)
Date:   2024-04-29 09:21:12 -07:00
Parent: 0a5ff68d9d
Commit: 2176176fef
1 changed file with 28 additions and 1 deletion

```diff
@@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
     print(m)
 
 
+@run
+# CHECK-LABEL: test_broadcast_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
+def test_broadcast_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.broadcast_to(x, (y.shape[0], -1))
+
+    # Sample inputs
+    x = torch.randn(1, 2)
+    y = torch.randn(10)
+
+    dim_0 = Dim("dim_0")
+    dynamic_shapes = {
+        "x": {},
+        "y": {0: dim_0},
+    }
+
+    m = fx.export_and_import(
+        Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
+    )
+    print(m)
+
+
 @make_boxed_compiler
 def fx_import_aot_autograd_backend(
     gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
@@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
 
 @run
 # CHECK-LABEL: test_stateless_fx_import
-# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
 def test_stateless_fx_import():
```
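
A hedged usage sketch (assuming a torch-mlir install that exposes the same `torch_mlir.fx.export_and_import` helper the test uses): the imported module can also be checked programmatically for the dynamic result type, mirroring the CHECK line above:

```python
# Hedged sketch mirroring the new test; asserts the dynamic dim surfaces
# as `?` in the imported IR. Assumes torch-mlir is installed as torch_mlir.
import torch
from torch import nn
from torch.export import Dim
from torch_mlir import fx


class Basic(nn.Module):
    def forward(self, x, y):
        return torch.broadcast_to(x, (y.shape[0], -1))


m = fx.export_and_import(
    Basic(),
    torch.randn(1, 2),
    torch.randn(10),
    dynamic_shapes={"x": {}, "y": {0: Dim("dim_0")}},
    func_name="test_net",
)
# The broadcast result type should be !torch.vtensor<[?,2],f32>.
assert "!torch.vtensor<[?,2],f32>" in str(m)
```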