mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] pre-pend named buffers to parameter list (#3178)
weights and biases and other model parameters appear as a separate data structure to the traced graph, but are needed when running the MLIR compiled code; this PR implements that extended functionalitypull/3183/head
parent
b66eabd492
commit
491f4820f5
|
@ -137,7 +137,7 @@ def export_and_import(f, *args, **kwargs):
|
||||||
def sparse_jit(f, *args, **kwargs):
|
def sparse_jit(f, *args, **kwargs):
|
||||||
"""This method compiles and runs the given callable using linalg backend."""
|
"""This method compiles and runs the given callable using linalg backend."""
|
||||||
# Import module and lower into Linalg IR.
|
# Import module and lower into Linalg IR.
|
||||||
module = export_and_import(f, *args, *kwargs)
|
module = export_and_import(f, *args, **kwargs)
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
module,
|
module,
|
||||||
(
|
(
|
||||||
|
@ -152,12 +152,19 @@ def sparse_jit(f, *args, **kwargs):
|
||||||
backend = RefBackendLinalgOnTensorsBackend()
|
backend = RefBackendLinalgOnTensorsBackend()
|
||||||
compiled = backend.compile(module)
|
compiled = backend.compile(module)
|
||||||
invoker = backend.load(compiled)
|
invoker = backend.load(compiled)
|
||||||
|
xargs = []
|
||||||
|
# Prepare the buffer parameters (assume all dense).
|
||||||
|
# TODO: filters out scalar arguments, anything else?
|
||||||
|
params = dict(f.named_buffers(remove_duplicate=True))
|
||||||
|
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
|
||||||
|
for p in params_flat:
|
||||||
|
if len(p.shape) > 0:
|
||||||
|
xargs.append(p.numpy())
|
||||||
# Prepare input parameters. Sparse input tensors are split into
|
# Prepare input parameters. Sparse input tensors are split into
|
||||||
# their composite tensors. All PyTorch tensors are converted
|
# their composite tensors. All PyTorch tensors are converted
|
||||||
# to their backing numpy arrays. Note that the output consists
|
# to their backing numpy arrays. Note that the output consists
|
||||||
# of numpy arrays as well, which can trivially be reconstructed
|
# of numpy arrays as well, which can trivially be reconstructed
|
||||||
# into PyTorch tensors (dense and sparse).
|
# into PyTorch tensors (dense and sparse).
|
||||||
xargs = []
|
|
||||||
for a in args:
|
for a in args:
|
||||||
if a.layout is torch.sparse_coo:
|
if a.layout is torch.sparse_coo:
|
||||||
# Construct the additional position array required by MLIR with data
|
# Construct the additional position array required by MLIR with data
|
||||||
|
|
Loading…
Reference in New Issue