diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index b84805163..1c19babe0 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -204,6 +204,7 @@ def run(f): @run +# # CHECK-LABEL: test_sparse_id # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -250,6 +251,7 @@ def test_sparse_id(): @run +# # CHECK-LABEL: test_sparse_sum # CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -284,6 +286,7 @@ def test_sparse_sum(): @run +# # CHECK-LABEL: test_sparse_SpMV # CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -319,6 +322,7 @@ def test_sparse_SpMV(): @run +# # CHECK-LABEL: test_sparse_SpMM # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -361,6 +365,7 @@ def test_sparse_SpMM(): @run +# # CHECK-LABEL: test_sparse_eltwise # CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -428,6 +433,7 @@ def test_sparse_eltwise(): @run +# # CHECK-LABEL: test_sparse_coo3 # CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -473,6 +479,7 @@ def test_sparse_coo3(): @run +# # CHECK-LABEL: test_sparse_activation # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( @@ -518,3 +525,86 @@ def test_sparse_activation(): print(res2[2]) print(res2[3]) print(res2[4]) + + +@run +# +# CHECK-LABEL: test_sparse_network +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> { +# ... lots of IR ... +# CHECK-COUNT-15: torch.aten.mul.Tensor +# ... lots of IR ... +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.]) +# CHECK: torch.mlir +# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.] +# +def test_sparse_network(): + def spike(input): + return (input >= 0).float() + + def sqSum(input): + return (input * input).sum() + + class LIF(nn.Module): + def __init__(self): + super(LIF, self).__init__() + self.thresh = 1.0 + self.decay = 0.5 + self.act = spike + + def forward(self, X): + """A filter that yields a binary-valued sparse tensor.""" + mem = 0 + spike_pot = [] + T = X.size(-1) + for t in range(T): + mem = mem * self.decay + X[..., t] + spike = self.act(mem - self.thresh) + mem = mem * (1.0 - spike) + spike_pot.append(spike) + spike_pot = torch.stack(spike_pot, dim=-1) + # TODO: we would like to see something like + # return spike_pot.to_sparse() + return spike_pot + + class tdLayer(nn.Module): + def __init__(self, layer): + super(tdLayer, self).__init__() + self.layer = layer + + def forward(self, X): + T = X.size(-1) + out = [] + for t in range(T): + m = self.layer(X[..., t]) + out.append(m) + out = torch.stack(out, dim=-1) + return out + + class Block(nn.Module): + def __init__(self): + super(Block, self).__init__() + self.spike = LIF() + self.layer = tdLayer(sqSum) + + def forward(self, X): + out = self.spike(X) + out = self.layer(out) + return out + + net = Block() + x = torch.ones(2, 3, 8, 8) + m = export_and_import(net, x) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(x) + res2 = sparse_jit(net, x) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2)