mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add simple sparsity "propagation" rules (#3297)
While waiting for the full resolution of feature request https://github.com/pytorch/pytorch/issues/117188 (which will propagate sparsity the right way in upstream PyTorch for all FX Graphs), this minor change allows us to start testing sparsity "within" a network, rather than just the parameters. Feel free to add your own rules for testing (but within reason for what will be done upstream). Note, two TODOs need to be addressed to work around some pending issues to make the JIT execution work.pull/3294/head
parent
9be6877c22
commit
c77f3b559a
|
@ -94,11 +94,11 @@ def sparse_export(
|
|||
|
||||
is addressed, this wrapper provides support for the sparse
|
||||
tensor types by first converting all operands to dense tensors,
|
||||
building the traced graph as for the dense case, and then
|
||||
annotation sparse parameters with their actual sparse layout
|
||||
attributes. This temporary solution accelerates testing
|
||||
torch-mlir with PyTorch sparse tensors until the issue is
|
||||
resolved.
|
||||
building the traced graph as for the dense case, then annotating
|
||||
sparse parameters with their actual sparse layout attributes,
|
||||
followed by some simple propagation rules. This temporary solution
|
||||
accelerates testing torch-mlir with PyTorch sparse tensors until
|
||||
the issue is resolved upstream.
|
||||
"""
|
||||
# Convert all arguments to dense.
|
||||
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
|
||||
|
@ -106,21 +106,23 @@ def sparse_export(
|
|||
# Build the regular FX traced graph with only dense arguments
|
||||
# (the current version would crash otherwise, see issue above).
|
||||
prog = torch.export.export(f, dargs, kwargs)
|
||||
# Annotate sparse arguments in the graph. Note that we currently
|
||||
# only account for sparsity defined by the user inputs to the model.
|
||||
# TODO: support sparsity in model parameters (weights, biases)
|
||||
# TODO: propagate sparsity into the layers
|
||||
# Annotate sparse arguments in the graph and apply some very
|
||||
# basic propagation rules for sparsity.
|
||||
specs = prog.graph_signature.input_specs
|
||||
alen = len(specs)
|
||||
k = 0
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
if i >= alen:
|
||||
break
|
||||
spec = specs[i]
|
||||
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
|
||||
if mask[k]:
|
||||
node.meta["sparsity"] = sparse_metadata(args[k])
|
||||
k = k + 1
|
||||
if node.op == "placeholder":
|
||||
# Argument.
|
||||
spec = specs[i]
|
||||
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
|
||||
if mask[k]:
|
||||
node.meta["sparsity"] = sparse_metadata(args[k])
|
||||
k = k + 1
|
||||
elif node.op == "call_function":
|
||||
# Zero preserving elt-wise unary op.
|
||||
if node.name in {"abs", "neg", "relu", "sin"}:
|
||||
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
return prog
|
||||
|
||||
|
||||
|
@ -170,8 +172,8 @@ def sparse_jit(f, *args, **kwargs):
|
|||
# Construct the additional position array required by MLIR with data
|
||||
# array([0, nnz]). The COO format always uses int64 indices.
|
||||
xargs.append(np.array([0, a._nnz()], dtype=np.int64))
|
||||
# Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
|
||||
# MLIR SoA COO representation.
|
||||
# Transform a tensor<ndim x nnz> into ndim x tensor<nnz> to conform
|
||||
# to the MLIR SoA COO representation.
|
||||
for idx in a._indices():
|
||||
xargs.append(idx.numpy())
|
||||
xargs.append(a._values().numpy())
|
||||
|
@ -204,13 +206,16 @@ def run(f):
|
|||
# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
|
||||
# CHECK: [ 0, 1, 10, 19]{{\]}}),
|
||||
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
|
||||
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: (array([0, 4]), array([0, 1, 2, 9]), array([ 0, 1, 10, 19]), array([-1000., -1., 1., 1000.]))
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
|
||||
# CHECK: [ 0, 1, 10, 19]{{\]}}),
|
||||
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
|
||||
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [0 4]
|
||||
# CHECK: [0 1 2 9]
|
||||
# CHECK: [ 0 1 10 19]
|
||||
# CHECK: [-1000. -1. 1. 1000.]
|
||||
#
|
||||
def test_sparse_id():
|
||||
class IdNet(torch.nn.Module):
|
||||
|
@ -233,7 +238,10 @@ def test_sparse_id():
|
|||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
print(res2[0])
|
||||
print(res2[1])
|
||||
print(res2[2])
|
||||
print(res2[3])
|
||||
|
||||
|
||||
@run
|
||||
|
@ -315,14 +323,14 @@ def test_sparse_SpMV():
|
|||
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
|
||||
#
|
||||
def test_sparse_SpMM():
|
||||
class MatMulNet(torch.nn.Module):
|
||||
|
@ -349,41 +357,30 @@ def test_sparse_SpMM():
|
|||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_eltwise
|
||||
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK: }
|
||||
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
|
||||
# CHECK: }
|
||||
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
|
||||
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
|
||||
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||
# CHECK: [ -3., -4.],
|
||||
# ...
|
||||
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
|
||||
# CHECK: layout=torch.sparse_csr)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[\[}}[ -1. -2.]
|
||||
# CHECK: [ -3. -4.]
|
||||
# ...
|
||||
# CHECK: [-61. -62.]
|
||||
# CHECK: [-63. -64.]{{\]\]}}
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
|
||||
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
|
||||
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||
# ...
|
||||
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
|
||||
# CHECK: layout=torch.sparse_csr)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: torch.mlir.batch
|
||||
#
|
||||
# CHECK: torch.mlir.batch
|
||||
# CHECK: {{\[\[}}[ -1. -2.]
|
||||
# CHECK: [ -3. -4.]
|
||||
# ...
|
||||
# CHECK: [-61. -62.]
|
||||
# CHECK: [-63. -64.]{{\]\]}}
|
||||
def test_sparse_eltwise():
|
||||
class EltNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -397,40 +394,43 @@ def test_sparse_eltwise():
|
|||
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
|
||||
)
|
||||
|
||||
# This yields a **batched** CSR.
|
||||
batch_input = dense_input.to_sparse_csr(dense_dim=0)
|
||||
m = export_and_import(net, batch_input)
|
||||
print(m)
|
||||
|
||||
# This yields a plain CSR with dense **sub**tensor
|
||||
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
||||
m = export_and_import(net, sparse_input)
|
||||
print(m)
|
||||
|
||||
# This yields a **batched** CSR.
|
||||
batch_input = dense_input.to_sparse_csr(dense_dim=0)
|
||||
m = export_and_import(net, batch_input)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
#
|
||||
# TODO: propagate sparsity into elt-wise (instead of dense result)
|
||||
res1 = net(sparse_input)
|
||||
res2 = sparse_jit(net, sparse_input)
|
||||
res3 = sparse_jit(net, batch_input)
|
||||
# TODO: make these work
|
||||
# res2 = sparse_jit(net, sparse_input)
|
||||
# res3 = sparse_jit(net, batch_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
print("torch.mlir.batch")
|
||||
print(res3)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_coo3
|
||||
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64>
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#[[$COO3]]>) -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
|
||||
# CHECK: }
|
||||
#
|
||||
# TODO: make sure sparsity propagates through relu into the output and test actual JIT output
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 1, 4, 9, 9],
|
||||
# CHECK: [ 0, 1, 1, 5, 19, 19],
|
||||
# CHECK: [ 0, 1, 3, 6, 28, 29]{{\]}}),
|
||||
# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]),
|
||||
# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
|
||||
# CHECK: torch.mlir
|
||||
#
|
||||
def test_sparse_coo3():
|
||||
class COO3Net(torch.nn.Module):
|
||||
|
@ -450,3 +450,11 @@ def test_sparse_coo3():
|
|||
|
||||
m = export_and_import(net, sparse_input)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(sparse_input)
|
||||
# TODO: make coo3 work
|
||||
# res2 = sparse_jit(net, sparse_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
|
|
Loading…
Reference in New Issue