mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add a true network to our NN tests (#3305)
Objective: make the to_sparse work end-to-end!pull/3308/head
parent
cff144b3ac
commit
89bb7404c1
|
@ -204,6 +204,7 @@ def run(f):
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_id
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -250,6 +251,7 @@ def test_sparse_id():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_sum
|
||||
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -284,6 +286,7 @@ def test_sparse_sum():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_SpMV
|
||||
# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -319,6 +322,7 @@ def test_sparse_SpMV():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_SpMM
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -361,6 +365,7 @@ def test_sparse_SpMM():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_eltwise
|
||||
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -428,6 +433,7 @@ def test_sparse_eltwise():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_coo3
|
||||
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -473,6 +479,7 @@ def test_sparse_coo3():
|
|||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_activation
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
|
@ -518,3 +525,86 @@ def test_sparse_activation():
|
|||
print(res2[2])
|
||||
print(res2[3])
|
||||
print(res2[4])
|
||||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_network
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
|
||||
# ... lots of IR ...
|
||||
# CHECK-COUNT-15: torch.aten.mul.Tensor
|
||||
# ... lots of IR ...
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.])
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.]
|
||||
#
|
||||
def test_sparse_network():
|
||||
def spike(input):
|
||||
return (input >= 0).float()
|
||||
|
||||
def sqSum(input):
|
||||
return (input * input).sum()
|
||||
|
||||
class LIF(nn.Module):
|
||||
def __init__(self):
|
||||
super(LIF, self).__init__()
|
||||
self.thresh = 1.0
|
||||
self.decay = 0.5
|
||||
self.act = spike
|
||||
|
||||
def forward(self, X):
|
||||
"""A filter that yields a binary-valued sparse tensor."""
|
||||
mem = 0
|
||||
spike_pot = []
|
||||
T = X.size(-1)
|
||||
for t in range(T):
|
||||
mem = mem * self.decay + X[..., t]
|
||||
spike = self.act(mem - self.thresh)
|
||||
mem = mem * (1.0 - spike)
|
||||
spike_pot.append(spike)
|
||||
spike_pot = torch.stack(spike_pot, dim=-1)
|
||||
# TODO: we would like to see something like
|
||||
# return spike_pot.to_sparse()
|
||||
return spike_pot
|
||||
|
||||
class tdLayer(nn.Module):
|
||||
def __init__(self, layer):
|
||||
super(tdLayer, self).__init__()
|
||||
self.layer = layer
|
||||
|
||||
def forward(self, X):
|
||||
T = X.size(-1)
|
||||
out = []
|
||||
for t in range(T):
|
||||
m = self.layer(X[..., t])
|
||||
out.append(m)
|
||||
out = torch.stack(out, dim=-1)
|
||||
return out
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self):
|
||||
super(Block, self).__init__()
|
||||
self.spike = LIF()
|
||||
self.layer = tdLayer(sqSum)
|
||||
|
||||
def forward(self, X):
|
||||
out = self.spike(X)
|
||||
out = self.layer(out)
|
||||
return out
|
||||
|
||||
net = Block()
|
||||
x = torch.ones(2, 3, 8, 8)
|
||||
m = export_and_import(net, x)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(x)
|
||||
res2 = sparse_jit(net, x)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
|
Loading…
Reference in New Issue