[torch-mlir][sparse] add simple sparsity "propagation" rules (#3297)

While waiting for the full resolution of feature request
https://github.com/pytorch/pytorch/issues/117188
(which will propagate sparsity the right way in upstream PyTorch for all
FX Graphs), this minor change allows us to start testing sparsity
"within" a network, rather than just the parameters. Feel free to add
your own rules for testing (but within reason for what will be done
upstream).

Note that two TODOs, which work around some pending issues, still need
to be addressed to make the JIT execution work.
pull/3294/head
Aart Bik 2024-05-07 15:27:36 -07:00 committed by GitHub
parent 9be6877c22
commit c77f3b559a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 87 additions and 79 deletions

View File

@ -94,11 +94,11 @@ def sparse_export(
is addressed, this wrapper provides support for the sparse
tensor types by first converting all operands to dense tensors,
building the traced graph as for the dense case, and then
annotation sparse parameters with their actual sparse layout
attributes. This temporary solution accelerates testing
torch-mlir with PyTorch sparse tensors until the issue is
resolved.
building the traced graph as for the dense case, then annotating
sparse parameters with their actual sparse layout attributes,
followed by some simple propagation rules. This temporary solution
accelerates testing torch-mlir with PyTorch sparse tensors until
the issue is resolved upstream.
"""
# Convert all arguments to dense.
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
@ -106,21 +106,23 @@ def sparse_export(
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
# Annotate sparse arguments in the graph. Note that we currently
# only account for sparsity defined by the user inputs to the model.
# TODO: support sparsity in model parameters (weights, biases)
# TODO: propagate sparsity into the layers
# Annotate sparse arguments in the graph and apply some very
# basic propagation rules for sparsity.
specs = prog.graph_signature.input_specs
alen = len(specs)
k = 0
for i, node in enumerate(prog.graph.nodes):
if i >= alen:
break
spec = specs[i]
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
if mask[k]:
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
if node.op == "placeholder":
# Argument.
spec = specs[i]
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
if mask[k]:
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
# Zero preserving elt-wise unary op.
if node.name in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
return prog
@ -170,8 +172,8 @@ def sparse_jit(f, *args, **kwargs):
# Construct the additional position array required by MLIR with data
# array([0, nnz]). The COO format always uses int64 indices.
xargs.append(np.array([0, a._nnz()], dtype=np.int64))
# Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
# MLIR SoA COO representation.
# Transform a tensor<ndim x nnz> into ndim x tensor<nnz> to conform
# to the MLIR SoA COO representation.
for idx in a._indices():
xargs.append(idx.numpy())
xargs.append(a._values().numpy())
@ -204,13 +206,16 @@ def run(f):
# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK: [ 0, 1, 10, 19]{{\]}}),
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: (array([0, 4]), array([0, 1, 2, 9]), array([ 0, 1, 10, 19]), array([-1000., -1., 1., 1000.]))
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK: [ 0, 1, 10, 19]{{\]}}),
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: [0 4]
# CHECK: [0 1 2 9]
# CHECK: [ 0 1 10 19]
# CHECK: [-1000. -1. 1. 1000.]
#
def test_sparse_id():
class IdNet(torch.nn.Module):
@ -233,7 +238,10 @@ def test_sparse_id():
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)
print(res2[0])
print(res2[1])
print(res2[2])
print(res2[3])
@run
@ -315,14 +323,14 @@ def test_sparse_SpMV():
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
#
def test_sparse_SpMM():
class MatMulNet(torch.nn.Module):
@ -349,41 +357,30 @@ def test_sparse_SpMM():
@run
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> {
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
# CHECK: }
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32> {
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
# CHECK: }
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# CHECK: [ -3., -4.],
# ...
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
# CHECK: layout=torch.sparse_csr)
# CHECK: torch.mlir
# CHECK: {{\[\[}}[ -1. -2.]
# CHECK: [ -3. -4.]
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
# CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# ...
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
# CHECK: layout=torch.sparse_csr)
# CHECK: torch.mlir
# CHECK: torch.mlir.batch
#
# CHECK: torch.mlir.batch
# CHECK: {{\[\[}}[ -1. -2.]
# CHECK: [ -3. -4.]
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
def test_sparse_eltwise():
class EltNet(torch.nn.Module):
def __init__(self):
@ -397,40 +394,43 @@ def test_sparse_eltwise():
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
)
# This yields a **batched** CSR.
batch_input = dense_input.to_sparse_csr(dense_dim=0)
m = export_and_import(net, batch_input)
print(m)
# This yields a plain CSR with dense **sub**tensor
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
m = export_and_import(net, sparse_input)
print(m)
# This yields a **batched** CSR.
batch_input = dense_input.to_sparse_csr(dense_dim=0)
m = export_and_import(net, batch_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
#
# TODO: propagate sparsity into elt-wise (instead of dense result)
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
res3 = sparse_jit(net, batch_input)
# TODO: make these work
# res2 = sparse_jit(net, sparse_input)
# res3 = sparse_jit(net, batch_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)
print("torch.mlir.batch")
print(res3)
@run
# CHECK-LABEL: test_sparse_coo3
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> {
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64>
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64>
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#[[$COO3]]>) -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> {
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK: }
#
# TODO: make sure sparsity propagates through relu into the output and test actual JIT output
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 1, 4, 9, 9],
# CHECK: [ 0, 1, 1, 5, 19, 19],
# CHECK: [ 0, 1, 3, 6, 28, 29]{{\]}}),
# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]),
# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
#
def test_sparse_coo3():
class COO3Net(torch.nn.Module):
@ -450,3 +450,11 @@ def test_sparse_coo3():
m = export_and_import(net, sparse_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input)
# TODO: make coo3 work
# res2 = sparse_jit(net, sparse_input)
print("torch.sparse")
print(res1)
print("torch.mlir")