mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add decomposition features to sparse compiler (#3505)
Fixes https://github.com/llvm/torch-mlir/issues/3499
parent af236dab66
commit 6fece25ff3
@@ -49,6 +49,7 @@ DEFAULT_DECOMPOSITIONS = [
    torch.ops.aten.nan_to_num.default,
    torch.ops.aten.unbind,
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
    torch.ops.aten.diag,
]
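The ops above are entries in DEFAULT_DECOMPOSITIONS. As a rough illustration (not part of this commit), such a list can be turned into concrete decomposition functions with the upstream torch._decomp helper, which is presumably what get_decomposition_table() does with it:

    import torch
    from torch._decomp import get_decompositions

    # Build a decomposition table for a couple of the aten ops listed above.
    # Keys are OpOverloads, values are Python decomposition functions.
    table = get_decompositions(
        [
            torch.ops.aten.nan_to_num.default,
            torch.ops.aten.unbind,
        ]
    )
    print(len(table))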
@@ -12,6 +12,7 @@ import torch.export
import torch.nn as nn
import numpy as np

from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir.extras.fx_importer import SparsityMeta
from torch_mlir import ir
@@ -106,6 +107,9 @@ def sparse_export(
    # Build the regular FX traced graph with only dense arguments
    # (the current version would crash otherwise, see issue above).
    prog = torch.export.export(f, dargs, kwargs)
    decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)
    # Annotate sparse arguments in the graph and apply some very
    # basic propagation rules for sparsity.
    specs = prog.graph_signature.input_specs
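Together with the new import above, the decomposition step slots in right after torch.export.export. A minimal end-to-end sketch of that flow (SimpleMLP is a hypothetical module used only for illustration; it assumes the torch-mlir Python package is installed):

    import torch
    import torch.nn as nn
    from torch_mlir.extras.fx_decomp_util import get_decomposition_table

    class SimpleMLP(nn.Module):
        def forward(self, x):
            # unbind appears in DEFAULT_DECOMPOSITIONS, so it gets decomposed.
            return torch.relu(x).unbind(0)

    prog = torch.export.export(SimpleMLP(), (torch.randn(2, 3),))
    table = get_decomposition_table()
    if table:
        # Rewrite the exported graph so decomposed ops replace the originals.
        prog = prog.run_decompositions(table)
    print(prog.graph)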
@@ -120,7 +124,6 @@ def sparse_export(
            node.meta["sparsity"] = sparse_metadata(args[k])
            k = k + 1
        elif node.op == "call_function":
            # TODO: use upstream _opname implementation when available
            opname = node.target._schema.name.split("::")[1]
            # Zero preserving elt-wise unary op.
            if opname in {"abs", "neg", "relu", "sin"}:
@@ -131,7 +134,7 @@ def sparse_export(
                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                )
            # TODO: Uncomment this to hack sparsity into the network.
            # elif opname == "_to_dense":
            # elif opname == "_to_dense" or opname == "to_dense":
            #   # hack (assumes we never really want the to_dense for now)
            #   node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
            elif opname == "select" and node.args[0].meta.get("sparsity", None):
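The two hunks above cover the very basic sparsity propagation in sparse_export: for zero-preserving elementwise unary ops, the result of a sparse operand is annotated as sparse again. A hedged sketch of that idea, using a plain copy of the metadata rather than the SparsityMeta construction in the real code:

    # Illustrative only: ZERO_PRESERVING and the copy rule are a simplification
    # of what sparse_export does when it builds a fresh SparsityMeta.
    ZERO_PRESERVING = {"abs", "neg", "relu", "sin"}

    def propagate(node):
        # node is a torch.fx.Node; its inputs carry a "sparsity" entry in node.meta.
        opname = node.target._schema.name.split("::")[1]
        if opname in ZERO_PRESERVING:
            node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)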
@@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs):
    compiled = backend.compile(module)
    invoker = backend.load(compiled)
    xargs = []
    # Prepare the buffer parameters (assume all dense).
    # TODO: filters out scalar arguments, anything else?
    # Prepare all the named buffer parameters (assume all dense).
    # All scalar arguments are filtered out since they appear inline.
    params = dict(f.named_buffers(remove_duplicate=True))
    params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
    for p in params_flat:
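The rewritten comments describe the buffer handling in sparse_jit: every named buffer of the module is flattened into a plain list and passed along, while scalar arguments are skipped because they appear inline. A small sketch of that step (not from the commit; BatchNorm1d is just a convenient module with named buffers, and the .numpy() conversion for the backend invoker is an assumption):

    import torch
    import torch.nn as nn
    import torch.utils._pytree as pytree

    net = nn.BatchNorm1d(4)  # owns running_mean, running_var, num_batches_tracked
    params = dict(net.named_buffers(remove_duplicate=True))
    params_flat, params_spec = pytree.tree_flatten(params)
    xargs = [p.numpy() for p in params_flat]  # dense buffers for the invoker
    print([a.shape for a in xargs])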
@@ -339,6 +342,7 @@ def test_sparse_SpMV():

@run
#
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
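The #COO encoding checked here is the MLIR counterpart of a torch COO tensor with int64 positions and coordinates (posWidth/crdWidth = 64). A self-contained illustration of such an 8x8 input on the PyTorch side (indices and values are made up, not taken from the test):

    import torch

    indices = torch.tensor([[0, 1, 7], [2, 3, 5]], dtype=torch.int64)
    values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    A = torch.sparse_coo_tensor(indices, values, (8, 8))
    print(A.layout)  # torch.sparse_coo, matching !torch.vtensor<[8,8],f32,#COO>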
@@ -440,7 +444,7 @@ def test_sparse_eltwise():
    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    # TODO: make this work
    # TODO: make this work in MLIR
    # res3 = sparse_jit(net, batch_input)
    print("torch.sparse")
    print(res1)
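For the eager reference result res1, zero-preserving elementwise ops can be applied to a torch sparse COO tensor directly. A tiny standalone example of that path (shapes and values are illustrative only):

    import torch

    dense = torch.tensor([[0.0, -1.5], [2.0, 0.0]])
    s = dense.to_sparse()   # COO layout
    out = torch.neg(s)      # result stays sparse; only stored nonzeros are touched
    print(out.to_dense())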
@@ -657,7 +661,14 @@ def test_sparse_network():
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
#
# TODO: first row looks suspect...
#
# CHECK: torch.mlir
# CHECK: {{\[}}[0. 0. 0. 0. ]
# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418]
# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ]
# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
#
def test_sparse_feature_scaling():
    class Scale(nn.Module):
@@ -678,11 +689,11 @@ def test_sparse_feature_scaling():

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(f)
    # TODO: make this work
    # res2 = sparse_jit(net, f)
    res2 = sparse_jit(net, f)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)


@run