mirror of https://github.com/llvm/torch-mlir
[FX] Add broadcast test with dynamic dim (#3123)
This scenario was uncovered in a downstream test that failed with a previous snapshot of torch-mlir. See https://github.com/cruise-automation/mlir-tcp/actions/runs/8605480116/job/23581829102?pr=65. ``` File "/home/runner/.cache/bazel/_bazel_runner/ce288f117ee4ca92dc028a6a28476a3d/sandbox/processwrapper-sandbox/2380/execroot/mlir-tcp/bazel-out/k8-opt-exec-2B5CBBC6/bin/test/AotCompile/broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic_torch_exporter.runfiles/pip_deps_torch_mlir/site-packages/torch_mlir/extras/fx_importer.py", line 969, in value_info_to_type raise NotImplementedError( NotImplementedError: Could not deduce type from value info: tensor_meta=None, val=s1, sparsity=None ``` It seems to have resolved on current HEAD. Adding this test to ensure coverage in the future.pull/3261/head
parent
0a5ff68d9d
commit
2176176fef
|
@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
|||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
|
||||
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
|
||||
def test_broadcast_with_dynamic_shapes():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.broadcast_to(x, (y.shape[0], -1))
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(1, 2)
|
||||
y = torch.randn(10)
|
||||
|
||||
dim_0 = Dim("dim_0")
|
||||
dynamic_shapes = {
|
||||
"x": {},
|
||||
"y": {0: dim_0},
|
||||
}
|
||||
|
||||
m = fx.export_and_import(
|
||||
Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@make_boxed_compiler
|
||||
def fx_import_aot_autograd_backend(
|
||||
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
|
@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
|
|||
|
||||
@run
|
||||
# CHECK-LABEL: test_stateless_fx_import
|
||||
# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
|
||||
def test_stateless_fx_import():
|
||||
|
|
Loading…
Reference in New Issue