[torch-mlir][sparse] add decomposition features to sparse compiler (#3505)

Fixes https://github.com/llvm/torch-mlir/issues/3499
Aart Bik 2024-06-28 10:18:36 -07:00 committed by GitHub
parent af236dab66
commit 6fece25ff3
2 changed files with 19 additions and 7 deletions
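
For context (not part of the patch): the new flow first exports the module to an FX graph and then runs the registered decompositions, so that ops without direct sparse support, such as aten.diag, are rewritten into simpler aten ops before the FX importer sees them. A minimal sketch of that flow, using a hypothetical Diag module and the get_decomposition_table helper imported in the diff below:

import torch
from torch_mlir.extras.fx_decomp_util import get_decomposition_table

class Diag(torch.nn.Module):
    def forward(self, v):
        return torch.diag(v)  # emits aten.diag, expected to decompose away

prog = torch.export.export(Diag(), (torch.arange(4.0),))
table = get_decomposition_table()
if table:  # mirrors the guarded call added to sparse_export below
    prog = prog.run_decompositions(table)
print(prog.graph_module.graph)  # aten.diag should no longer appear directly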


@@ -49,6 +49,7 @@ DEFAULT_DECOMPOSITIONS = [
     torch.ops.aten.nan_to_num.default,
     torch.ops.aten.unbind,
     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+    torch.ops.aten.diag,
 ]
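
DEFAULT_DECOMPOSITIONS is the op list that get_decomposition_table turns into an actual decomposition table. A rough sketch of that step, assuming the helper resolves the list against PyTorch's decomposition registry (the real implementation may differ); the op subset below is illustrative only:

import torch
from torch._decomp import get_decompositions

ops = [
    torch.ops.aten.nan_to_num.default,
    torch.ops.aten.unbind,
    torch.ops.aten.diag,  # the op this commit adds to the list
]
table = get_decompositions(ops)  # maps op overloads to Python decomposition functions
print(torch.ops.aten.diag.default in table)  # expect True if a decomposition is registered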


@@ -12,6 +12,7 @@ import torch.export
 import torch.nn as nn
 import numpy as np
+from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
 from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir
@@ -106,6 +107,9 @@ def sparse_export(
     # Build the regular FX traced graph with only dense arguments
     # (the current version would crash otherwise, see issue above).
     prog = torch.export.export(f, dargs, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     # Annotate sparse arguments in the graph and apply some very
     # basic propagation rules for sparsity.
     specs = prog.graph_signature.input_specs
@@ -120,7 +124,6 @@ def sparse_export(
                 node.meta["sparsity"] = sparse_metadata(args[k])
                 k = k + 1
         elif node.op == "call_function":
-            # TODO: use upstream _opname implementation when available
             opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
             if opname in {"abs", "neg", "relu", "sin"}:
@@ -131,7 +134,7 @@ def sparse_export(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
             # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense":
+            # elif opname == "_to_dense" or opname == "to_dense":
             #     # hack (assumes we never really want the to_dense for now)
             #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
             elif opname == "select" and node.args[0].meta.get("sparsity", None):
@@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs):
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
     xargs = []
-    # Prepare the buffer parameters (assume all dense).
-    # TODO: filters out scalar arguments, anything else?
+    # Prepare all the named buffer parameters (assume all dense).
+    # All scalar arguments are filtered out since they appear inline.
     params = dict(f.named_buffers(remove_duplicate=True))
     params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
     for p in params_flat:
@@ -339,6 +342,7 @@ def test_sparse_SpMV():
 @run
 #
+# CHECK-LABEL: test_sparse_SpMM
 # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
 # CHECK: func.func @main(
 # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
@@ -440,7 +444,7 @@ def test_sparse_eltwise():
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(sparse_input)
     res2 = sparse_jit(net, sparse_input)
-    # TODO: make this work
+    # TODO: make this work in MLIR
     # res3 = sparse_jit(net, batch_input)
     print("torch.sparse")
     print(res1)
@@ -657,7 +661,14 @@ def test_sparse_network():
 # CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
 # CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
 # CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
+#
+# TODO: first row looks suspect...
+#
 # CHECK: torch.mlir
+# CHECK: {{\[}}[0. 0. 0. 0. ]
+# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418]
+# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ]
+# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
 #
 def test_sparse_feature_scaling():
     class Scale(nn.Module):
@@ -678,11 +689,11 @@ def test_sparse_feature_scaling():
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(f)
-    # TODO: make this work
-    # res2 = sparse_jit(net, f)
+    res2 = sparse_jit(net, f)
     print("torch.sparse")
     print(res1)
     print("torch.mlir")
+    print(res2)
 @run