[torch-mlir][sparse] add a true network to our NN tests (#3305)

Objective: make to_sparse work end-to-end!
pull/3308/head
Aart Bik 2024-05-08 21:18:42 -07:00 committed by GitHub
parent cff144b3ac
commit 89bb7404c1
1 changed file with 90 additions and 0 deletions
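For context: the objective refers to PyTorch's Tensor.to_sparse(), which converts a dense tensor into sparse COO form (coordinate/value pairs). A minimal sketch of that conversion, independent of this patch:

    import torch

    dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
    coo = dense.to_sparse()  # COO layout: coordinates plus values
    print(coo)
    # tensor(indices=tensor([[0, 1],
    #                        [1, 0]]),
    #        values=tensor([1., 2.]),
    #        size=(2, 2), nnz=2, layout=torch.sparse_coo)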


@@ -204,6 +204,7 @@ def run(f):
@run
#
# CHECK-LABEL: test_sparse_id
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -250,6 +251,7 @@ def test_sparse_id():
@run
#
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
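The $CSR encoding above (d0 dense, d1 compressed) is the MLIR rendering of PyTorch's CSR layout; the posWidth/crdWidth of 64 reflect PyTorch's default int64 index tensors. A minimal sketch of the PyTorch side:

    import torch

    m = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    csr = m.to_sparse_csr()
    print(csr.crow_indices())  # row positions: tensor([0, 1, 3])
    print(csr.col_indices())   # column coordinates: tensor([1, 0, 1])
    print(csr.values())        # tensor([1., 2., 3.])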
@@ -284,6 +286,7 @@ def test_sparse_sum():
@run
#
# CHECK-LABEL: test_sparse_SpMV
# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
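The $BSR map (floordiv 2 / mod 2 on both dimensions) describes 2x2 blocks, i.e. PyTorch's BSR layout with blocksize (2, 2). A minimal sketch:

    import torch

    m = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    bsr = m.to_sparse_bsr((2, 2))  # 2x2 blocks: outer block grid, inner dense blocks
    print(bsr.crow_indices())      # block-row positions: tensor([0, 2, 4])
    print(bsr.col_indices())       # block-column coordinates: tensor([0, 1, 0, 1])
    print(bsr.values().shape)      # torch.Size([4, 2, 2]): one dense 2x2 block per entry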
@@ -319,6 +322,7 @@ def test_sparse_SpMV():
@run
#
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -361,6 +365,7 @@ def test_sparse_SpMM():
@run
#
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
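The $CSRD encoding (dense, compressed, dense) is CSR over the first two dimensions with a trailing dense dimension, which on the PyTorch side corresponds to to_sparse_csr with a dense_dim argument. A minimal sketch:

    import torch

    t = torch.ones(2, 3, 4)
    csrd = t.to_sparse_csr(dense_dim=1)  # compress dims (0, 1), keep dim 2 dense
    print(csrd.values().shape)           # torch.Size([6, 4]): nnz entries, each a dense row of 4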
@@ -428,6 +433,7 @@ def test_sparse_eltwise():
@run
#
# CHECK-LABEL: test_sparse_coo3
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
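The $COO3 encoding generalizes COO to three dimensions: one compressed(nonunique) level followed by singleton levels for the remaining coordinates. On the PyTorch side this is simply to_sparse() on a 3-D tensor; a minimal sketch:

    import torch

    t = torch.ones(2, 2, 2)
    coo3 = t.to_sparse().coalesce()
    print(coo3.indices().shape)  # torch.Size([3, 8]): one coordinate row per dimension
    print(coo3.values().shape)   # torch.Size([8])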
@@ -473,6 +479,7 @@ def test_sparse_coo3():
@run
#
# CHECK-LABEL: test_sparse_activation
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -518,3 +525,86 @@ def test_sparse_activation():
    print(res2[2])
    print(res2[3])
    print(res2[4])
@run
#
# CHECK-LABEL: test_sparse_network
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
# ... lots of IR ...
# CHECK-COUNT-15: torch.aten.mul.Tensor
# ... lots of IR ...
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.])
# CHECK: torch.mlir
# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.]
#
def test_sparse_network():
    def spike(input):
        return (input >= 0).float()

    def sqSum(input):
        return (input * input).sum()

    class LIF(nn.Module):
        def __init__(self):
            super(LIF, self).__init__()
            self.thresh = 1.0
            self.decay = 0.5
            self.act = spike

        def forward(self, X):
            """A filter that yields a binary-valued sparse tensor."""
            mem = 0
            spike_pot = []
            T = X.size(-1)
            for t in range(T):
                # Leaky integration: decay the membrane and add the input at t.
                mem = mem * self.decay + X[..., t]
                # Fire where the membrane potential crosses the threshold.
                spike = self.act(mem - self.thresh)
                # Hard-reset the membrane wherever a spike fired.
                mem = mem * (1.0 - spike)
                spike_pot.append(spike)
            spike_pot = torch.stack(spike_pot, dim=-1)
            # TODO: we would like to see something like
            #   return spike_pot.to_sparse()
            return spike_pot

    class tdLayer(nn.Module):
        """Applies the wrapped layer independently at every time step."""

        def __init__(self, layer):
            super(tdLayer, self).__init__()
            self.layer = layer

        def forward(self, X):
            T = X.size(-1)
            out = []
            for t in range(T):
                m = self.layer(X[..., t])
                out.append(m)
            out = torch.stack(out, dim=-1)
            return out

    class Block(nn.Module):
        def __init__(self):
            super(Block, self).__init__()
            self.spike = LIF()
            self.layer = tdLayer(sqSum)

        def forward(self, X):
            out = self.spike(X)
            out = self.layer(out)
            return out

    net = Block()
    x = torch.ones(2, 3, 8, 8)
    m = export_and_import(net, x)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)
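Why 48: with the all-ones (2, 3, 8, 8) input, the membrane reaches the threshold at every one of the 8 time steps, so every spike is 1.0 and sqSum over a (2, 3, 8) slice of ones gives 2 * 3 * 8 = 48 per step, matching the CHECK lines above. Once to_sparse lowers end-to-end, the TODO in LIF.forward could be realized along these lines (a hypothetical sketch reusing the LIF class above, not part of this patch):

    class SparseLIF(LIF):
        # Hypothetical variant: emit the binary spike train as a sparse
        # COO tensor instead of a dense one.
        def forward(self, X):
            return super().forward(X).to_sparse()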