mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] minor tweaks in sparse tests (#3311)
(1) test full pytorch output for eltwise (2) use "random" input for LIF, to get general sparse tensor (3) introduce way to get true sparsity into network (needs backend fix first)pull/3313/head
parent
a033bbfe6c
commit
97a822de0a
|
@ -128,6 +128,10 @@ def sparse_export(
|
|||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||
)
|
||||
# TODO: Uncomment this to hack sparsity into the network.
|
||||
# elif node.name == "_to_dense":
|
||||
# # hack (assumes we never really want the to_dense for now)
|
||||
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
return prog
|
||||
|
||||
|
||||
|
@ -384,10 +388,14 @@ def test_sparse_SpMM():
|
|||
# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
|
||||
# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
|
||||
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||
# ...
|
||||
# CHECK: [ -3., -4.],
|
||||
# CHECK: [ -5., -6.],
|
||||
# CHECK: [ -7., -8.],
|
||||
# CHECK: [ -9., -10.],
|
||||
# CHECK: [-11., -12.],
|
||||
# CHECK: [-13., -14.],
|
||||
# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
|
||||
# CHECK: layout=torch.sparse_csr)
|
||||
#
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [0 2 4 6 8]
|
||||
# CHECK: [0 1 0 1 0 1 0 1]
|
||||
|
@ -421,7 +429,7 @@ def test_sparse_eltwise():
|
|||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(sparse_input)
|
||||
res2 = sparse_jit(net, sparse_input)
|
||||
# TODO: make these work
|
||||
# TODO: make this work
|
||||
# res3 = sparse_jit(net, batch_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
|
@ -538,9 +546,9 @@ def test_sparse_activation():
|
|||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.])
|
||||
# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.]
|
||||
# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.]
|
||||
#
|
||||
def test_sparse_network():
|
||||
def spike(input):
|
||||
|
@ -565,10 +573,9 @@ def test_sparse_network():
|
|||
mem = mem * self.decay + X[..., t]
|
||||
spike = self.act(mem - self.thresh)
|
||||
mem = mem * (1.0 - spike)
|
||||
spike = spike.to_sparse().to_dense() # prop hack
|
||||
spike_pot.append(spike)
|
||||
spike_pot = torch.stack(spike_pot, dim=-1)
|
||||
# TODO: we would like to see something like
|
||||
# return spike_pot.to_sparse()
|
||||
return spike_pot
|
||||
|
||||
class tdLayer(nn.Module):
|
||||
|
@ -597,7 +604,11 @@ def test_sparse_network():
|
|||
return out
|
||||
|
||||
net = Block()
|
||||
x = torch.ones(2, 3, 8, 8)
|
||||
|
||||
# Get a random (but reproducible) input, so that a
|
||||
# general sparse tensor appears after LIF.
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(2, 3, 8, 8)
|
||||
m = export_and_import(net, x)
|
||||
print(m)
|
||||
|
||||
|
|
Loading…
Reference in New Issue