From 2176176fefd696d929b9d61b5587a419fae8386d Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Mon, 29 Apr 2024 09:21:12 -0700
Subject: [PATCH] [FX] Add broadcast test with dynamic dim (#3123)

This scenario was uncovered in a downstream test that failed with a
previous snapshot of torch-mlir. See
https://github.com/cruise-automation/mlir-tcp/actions/runs/8605480116/job/23581829102?pr=65.

```
  File "/home/runner/.cache/bazel/_bazel_runner/ce288f117ee4ca92dc028a6a28476a3d/sandbox/processwrapper-sandbox/2380/execroot/mlir-tcp/bazel-out/k8-opt-exec-2B5CBBC6/bin/test/AotCompile/broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic_torch_exporter.runfiles/pip_deps_torch_mlir/site-packages/torch_mlir/extras/fx_importer.py", line 969, in value_info_to_type
    raise NotImplementedError(
NotImplementedError: Could not deduce type from value info: tensor_meta=None, val=s1, sparsity=None
```

It seems to have resolved on current HEAD. Adding this test to ensure
coverage in the future.
---
 test/python/fx_importer/basic_test.py | 29 ++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 08ef9fdc9..fde318630 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
     print(m)
 
 
+@run
+# CHECK-LABEL: test_broadcast_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
+def test_broadcast_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.broadcast_to(x, (y.shape[0], -1))
+
+    # Sample inputs
+    x = torch.randn(1, 2)
+    y = torch.randn(10)
+
+    dim_0 = Dim("dim_0")
+    dynamic_shapes = {
+        "x": {},
+        "y": {0: dim_0},
+    }
+
+    m = fx.export_and_import(
+        Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
+    )
+    print(m)
+
+
 @make_boxed_compiler
 def fx_import_aot_autograd_backend(
     gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
@@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
 
 @run
 # CHECK-LABEL: test_stateless_fx_import
-# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
 def test_stateless_fx_import():