diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index a3a98f0dd..23b573211 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -137,7 +137,7 @@ def export_and_import(f, *args, **kwargs): def sparse_jit(f, *args, **kwargs): """This method compiles and runs the given callable using linalg backend.""" # Import module and lower into Linalg IR. - module = export_and_import(f, *args, *kwargs) + module = export_and_import(f, *args, **kwargs) run_pipeline_with_repro_report( module, ( @@ -152,12 +152,19 @@ def sparse_jit(f, *args, **kwargs): backend = RefBackendLinalgOnTensorsBackend() compiled = backend.compile(module) invoker = backend.load(compiled) + xargs = [] + # Prepare the buffer parameters (assume all dense). + # TODO: filters out scalar arguments, anything else? + params = dict(f.named_buffers(remove_duplicate=True)) + params_flat, params_spec = torch.utils._pytree.tree_flatten(params) + for p in params_flat: + if len(p.shape) > 0: + xargs.append(p.numpy()) # Prepare input parameters. Sparse input tensors are split into # their composite tensors. All PyTorch tensors are converted # to their backing numpy arrays. Note that the output consists # of numpy arrays as well, which can trivially be reconstructed # into PyTorch tensors (dense and sparse). - xargs = [] for a in args: if a.layout is torch.sparse_coo: # Construct the additional position array required by MLIR with data