mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add decomposition features to sparse compiler (#3505)
Fixes https://github.com/llvm/torch-mlir/issues/3499
parent af236dab66
commit 6fece25ff3
@@ -49,6 +49,7 @@ DEFAULT_DECOMPOSITIONS = [
     torch.ops.aten.nan_to_num.default,
     torch.ops.aten.unbind,
     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+    torch.ops.aten.diag,
 ]
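
The DEFAULT_DECOMPOSITIONS list feeds PyTorch's decomposition machinery, so adding torch.ops.aten.diag means aten.diag can be rewritten into more primitive ops before the graph reaches the importer. As a rough illustration (not the fx_decomp_util code itself), a table for such an op list is typically built with the upstream torch._decomp helper:

import torch
from torch._decomp import get_decompositions

# Sketch: turn a DEFAULT_DECOMPOSITIONS-style op list into a decomposition
# table; the exact contents of the dict depend on the PyTorch version.
decomp_table = get_decompositions(
    [
        torch.ops.aten.nan_to_num.default,
        torch.ops.aten.diag,
    ]
)
print(len(decomp_table), "decomposition rules registered")
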
@@ -12,6 +12,7 @@ import torch.export
 import torch.nn as nn
 import numpy as np

+from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
 from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir

@@ -106,6 +107,9 @@ def sparse_export(
     # Build the regular FX traced graph with only dense arguments
     # (the current version would crash otherwise, see issue above).
     prog = torch.export.export(f, dargs, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     # Annotate sparse arguments in the graph and apply some very
     # basic propagation rules for sparsity.
     specs = prog.graph_signature.input_specs
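
In sparse_export this table is applied right after torch.export.export: run_decompositions post-processes the ExportedProgram so any op with an entry in the table is expanded before sparsity annotation and FX import. A minimal standalone sketch of the same pattern (Net and its input are made up for illustration; which ops remain after decomposition depends on the PyTorch version):

import torch
from torch._decomp import get_decompositions


class Net(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.diag(x)


prog = torch.export.export(Net(), (torch.arange(4, dtype=torch.float32),))
table = get_decompositions([torch.ops.aten.diag])
if table:
    prog = prog.run_decompositions(table)
# After decomposition, aten.diag should no longer appear as a call target.
print([node.target for node in prog.graph.nodes if node.op == "call_function"])
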
@@ -120,7 +124,6 @@ def sparse_export(
                 node.meta["sparsity"] = sparse_metadata(args[k])
             k = k + 1
         elif node.op == "call_function":
-            # TODO: use upstream _opname implementation when available
             opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
             if opname in {"abs", "neg", "relu", "sin"}:

@@ -131,7 +134,7 @@ def sparse_export(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
             # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense":
+            # elif opname == "_to_dense" or opname == "to_dense":
             #     # hack (assumes we never really want the to_dense for now)
             #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
             elif opname == "select" and node.args[0].meta.get("sparsity", None):
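
The deleted TODO concerned how an op name is recovered from a call_function node: node.target is an OpOverload whose _schema.name has the form "namespace::opname", so splitting on "::" yields the bare name used by these propagation rules. A quick sketch of that lookup outside the importer (the helper name is ours, not part of the file):

import torch


def opname_of(target) -> str:
    # Same trick as in sparse_export: "aten::relu" -> "relu".
    return target._schema.name.split("::")[1]


# OpOverloads carry a _schema whose name is "namespace::opname".
assert opname_of(torch.ops.aten.relu.default) == "relu"
assert opname_of(torch.ops.aten.abs.default) == "abs"
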
@@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs):
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
     xargs = []
-    # Prepare the buffer parameters (assume all dense).
-    # TODO: filters out scalar arguments, anything else?
+    # Prepare all the named buffer parameters (assume all dense).
+    # All scalar arguments are filtered out since they appear inline.
     params = dict(f.named_buffers(remove_duplicate=True))
     params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
     for p in params_flat:
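
The reworded comments describe how sparse_jit gathers module state: every named buffer is flattened with torch.utils._pytree and handed to the compiled artifact as a dense argument, while Python scalars never reach this path because they are inlined during export. A small sketch of that flattening step on a made-up module:

import torch
import torch.nn as nn


class WithBuffer(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(3))

    def forward(self, x):
        return x + self.bias


net = WithBuffer()
params = dict(net.named_buffers(remove_duplicate=True))
# tree_flatten returns a flat list of tensors plus a spec to rebuild the dict.
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
print([p.shape for p in params_flat])
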
@@ -339,6 +342,7 @@ def test_sparse_SpMV():

 @run
 #
+# CHECK-LABEL: test_sparse_SpMM
 # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
 # CHECK: func.func @main(
 # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
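
The COO encoding in these CHECK lines (compressed(nonunique) on d0, singleton(soa) on d1, 64-bit positions and coordinates) is how the importer renders an ordinary PyTorch COO sparse tensor whose indices are int64. A quick sketch of constructing an 8x8 f32 input of that kind in plain PyTorch (values are arbitrary):

import torch

# An 8x8 f32 COO tensor like the %[[A]] argument above; PyTorch keeps COO
# indices as int64, which is where posWidth = 64 / crdWidth = 64 come from.
dense = torch.zeros(8, 8, dtype=torch.float32)
dense[0, 0] = 1.0
dense[3, 5] = 2.0
sparse_input = dense.to_sparse()  # layout becomes torch.sparse_coo
print(sparse_input.layout, sparse_input.coalesce().indices().dtype)
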
@@ -440,7 +444,7 @@ def test_sparse_eltwise():
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(sparse_input)
     res2 = sparse_jit(net, sparse_input)
-    # TODO: make this work
+    # TODO: make this work in MLIR
     # res3 = sparse_jit(net, batch_input)
     print("torch.sparse")
     print(res1)
@@ -657,7 +661,14 @@ def test_sparse_network():
 # CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
 # CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
 # CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
+#
+# TODO: first row looks suspect...
+#
 # CHECK: torch.mlir
+# CHECK: {{\[}}[0. 0. 0. 0. ]
+# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418]
+# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ]
+# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
 #
 def test_sparse_feature_scaling():
     class Scale(nn.Module):
@@ -678,11 +689,11 @@ def test_sparse_feature_scaling():

     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(f)
-    # TODO: make this work
-    # res2 = sparse_jit(net, f)
+    res2 = sparse_jit(net, f)
     print("torch.sparse")
     print(res1)
     print("torch.mlir")
+    print(res2)


 @run
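
With sparse_jit now exercised in test_sparse_feature_scaling, the test prints both the torch.sparse reference result and the compiled result, and FileCheck compares the printed values. When writing such checks it can help to confirm the two paths agree numerically first; a hedged sketch of that comparison (res1/res2 are placeholders for the torch.Tensor reference and the numpy-style output that appears in the CHECK lines above):

import numpy as np
import torch

# Stand-ins for the two results compared by the tests.
res1 = torch.tensor([[0.25, 0.75], [0.40, 0.60]])
res2 = np.array([[0.25, 0.75], [0.40, 0.60]], dtype=np.float32)
assert np.allclose(res1.detach().numpy(), res2, rtol=1e-5, atol=1e-6)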