[torch-mlir][sparse] minor tweaks in sparse tests (#3311)

(1) test full pytorch output for eltwise
(2) use a "random" input for LIF, to get a general sparse tensor
(3) introduce a way to get true sparsity into the network (needs a
backend fix first)
Aart Bik 2024-05-09 10:03:25 -07:00 committed by GitHub
parent a033bbfe6c
commit 97a822de0a
1 changed file with 19 additions and 8 deletions
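Point (2) above matters because an all-ones input makes every LIF neuron cross its threshold identically, so the resulting spike tensor has no interesting sparsity pattern. A minimal standalone sketch of the effect (the threshold value here is hypothetical, not taken from the test):

import torch

torch.manual_seed(0)
thresh = 0.5  # hypothetical threshold, for illustration only
spikes_ones = (torch.ones(2, 3, 8, 8) - thresh > 0).float()  # fires everywhere: uniform, no real sparsity
spikes_rand = (torch.rand(2, 3, 8, 8) - thresh > 0).float()  # mixed zeros/ones: a general sparse pattern
print(spikes_ones.mean().item(), spikes_rand.mean().item())  # 1.0 vs. roughly 0.5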

@@ -128,6 +128,10 @@ def sparse_export(
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif node.name == "_to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
return prog
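The hack in the commented-out branch treats _to_dense as an identity for sparsity bookkeeping: the exporter only annotates ops it knows are sparse, so copying the operand's annotation keeps sparsity flowing past the densifying op. A hedged sketch of that propagation step (node/meta shapes are illustrative, not the importer's actual API):

def propagate_past_to_dense(node):
    # Assume a densifying node should inherit its operand's sparsity
    # annotation, so downstream nodes still see a sparse layout.
    if node.name == "_to_dense":
        node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)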
@@ -384,10 +388,14 @@ def test_sparse_SpMM():
# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# ...
# CHECK: [ -3., -4.],
# CHECK: [ -5., -6.],
# CHECK: [ -7., -8.],
# CHECK: [ -9., -10.],
# CHECK: [-11., -12.],
# CHECK: [-13., -14.],
# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
# CHECK: layout=torch.sparse_csr)
#
# CHECK: torch.mlir
# CHECK: [0 2 4 6 8]
# CHECK: [0 1 0 1 0 1 0 1]
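The expanded CHECK lines above pin down PyTorch's full CSR printout rather than just its first row. A tiny standalone illustration of that print format (example values hypothetical):

import torch

a = torch.tensor([[0., 1.], [2., 0.]]).to_sparse_csr()
# Prints crow_indices, col_indices, values, size, nnz, and
# layout=torch.sparse_csr, the same shape of output the CHECK lines match.
print(a)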
@@ -421,7 +429,7 @@ def test_sparse_eltwise():
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
# TODO: make these work
# TODO: make this work
# res3 = sparse_jit(net, batch_input)
print("torch.sparse")
print(res1)
@@ -538,9 +546,9 @@ def test_sparse_activation():
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.])
# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
# CHECK: torch.mlir
# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.]
# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.]
#
def test_sparse_network():
def spike(input):
@@ -565,10 +573,9 @@ def test_sparse_network():
mem = mem * self.decay + X[..., t]
spike = self.act(mem - self.thresh)
mem = mem * (1.0 - spike)
spike = spike.to_sparse().to_dense() # prop hack
spike_pot.append(spike)
spike_pot = torch.stack(spike_pot, dim=-1)
# TODO: we would like to see something like
# return spike_pot.to_sparse()
return spike_pot
class tdLayer(nn.Module):
@@ -597,7 +604,11 @@ def test_sparse_network():
return out
net = Block()
x = torch.ones(2, 3, 8, 8)
# Get a random (but reproducible) input, so that a
# general sparse tensor appears after LIF.
torch.manual_seed(0)
x = torch.rand(2, 3, 8, 8)
m = export_and_import(net, x)
print(m)
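For reference, a hedged standalone sketch of the LIF update that the test's recurrent loop performs per time step (decay and thresh values hypothetical, not taken from the test):

import torch

def lif_step(mem, x_t, decay=0.5, thresh=0.5):
    mem = mem * decay + x_t             # leaky integration of the current input
    spike = (mem - thresh > 0).float()  # fire where the membrane crosses threshold
    mem = mem * (1.0 - spike)           # reset the membrane where a spike fired
    return mem, spike

# With a seeded random input, the spikes form a general sparse pattern.
torch.manual_seed(0)
mem = torch.zeros(2, 3, 8, 8)
mem, spikes = lif_step(mem, torch.rand(2, 3, 8, 8))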