mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] prepend named buffers to parameter list (#3178)
Weights, biases, and other model parameters appear as a data structure separate from the traced graph, but they are needed when running the MLIR-compiled code; this PR implements that extended functionality.
pull/3183/head
parent
b66eabd492
commit
491f4820f5
|
@ -137,7 +137,7 @@ def export_and_import(f, *args, **kwargs):
|
|||
def sparse_jit(f, *args, **kwargs):
|
||||
"""This method compiles and runs the given callable using linalg backend."""
|
||||
# Import module and lower into Linalg IR.
|
||||
module = export_and_import(f, *args, *kwargs)
|
||||
module = export_and_import(f, *args, **kwargs)
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
(
|
||||
|
@ -152,12 +152,19 @@ def sparse_jit(f, *args, **kwargs):
|
|||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(module)
|
||||
invoker = backend.load(compiled)
|
||||
xargs = []
|
||||
# Prepare the buffer parameters (assume all dense).
|
||||
# TODO: filters out scalar arguments, anything else?
|
||||
params = dict(f.named_buffers(remove_duplicate=True))
|
||||
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
|
||||
for p in params_flat:
|
||||
if len(p.shape) > 0:
|
||||
xargs.append(p.numpy())
|
||||
# Prepare input parameters. Sparse input tensors are split into
|
||||
# their composite tensors. All PyTorch tensors are converted
|
||||
# to their backing numpy arrays. Note that the output consists
|
||||
# of numpy arrays as well, which can trivially be reconstructed
|
||||
# into PyTorch tensors (dense and sparse).
|
||||
xargs = []
|
||||
for a in args:
|
||||
if a.layout is torch.sparse_coo:
|
||||
# Construct the additional position array required by MLIR with data
|
||||
|
|
Loading…
Reference in New Issue