[torch-mlir][sparse] pre-pend named buffers to parameter list (#3178)

weights and biases and other model parameters appear as a separate data
structure to the traced graph, but are needed when running the MLIR
compiled code; this PR implements that extended functionality
pull/3183/head
Aart Bik 2024-04-17 14:44:05 -07:00 committed by GitHub
parent b66eabd492
commit 491f4820f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 2 deletions

View File

@ -137,7 +137,7 @@ def export_and_import(f, *args, **kwargs):
def sparse_jit(f, *args, **kwargs): def sparse_jit(f, *args, **kwargs):
"""This method compiles and runs the given callable using linalg backend.""" """This method compiles and runs the given callable using linalg backend."""
# Import module and lower into Linalg IR. # Import module and lower into Linalg IR.
module = export_and_import(f, *args, *kwargs) module = export_and_import(f, *args, **kwargs)
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
module, module,
( (
@ -152,12 +152,19 @@ def sparse_jit(f, *args, **kwargs):
backend = RefBackendLinalgOnTensorsBackend() backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module) compiled = backend.compile(module)
invoker = backend.load(compiled) invoker = backend.load(compiled)
xargs = []
# Prepare the buffer parameters (assume all dense).
# TODO: filters out scalar arguments, anything else?
params = dict(f.named_buffers(remove_duplicate=True))
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
for p in params_flat:
if len(p.shape) > 0:
xargs.append(p.numpy())
# Prepare input parameters. Sparse input tensors are split into # Prepare input parameters. Sparse input tensors are split into
# their composite tensors. All PyTorch tensors are converted # their composite tensors. All PyTorch tensors are converted
# to their backing numpy arrays. Note that the output consists # to their backing numpy arrays. Note that the output consists
# of numpy arrays as well, which can trivially be reconstructed # of numpy arrays as well, which can trivially be reconstructed
# into PyTorch tensors (dense and sparse). # into PyTorch tensors (dense and sparse).
xargs = []
for a in args: for a in args:
if a.layout is torch.sparse_coo: if a.layout is torch.sparse_coo:
# Construct the additional position array required by MLIR with data # Construct the additional position array required by MLIR with data