mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add JIT test to expose pending issues (#2906)
This test exposes issues that need fixing (1) propagate sparsity into the FX graph (over elt-wise) (2) batched dimensions need a new "dense(batch)" formatpull/2909/head
parent
3e836d8dad
commit
24c2fc0b5f
|
@ -272,6 +272,22 @@ def test_sparse_SpMM():
|
|||
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
|
||||
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
|
||||
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||
# CHECK: [ -3., -4.],
|
||||
# ...
|
||||
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
|
||||
# CHECK: layout=torch.sparse_csr)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[\[}}[ -1. -2.]
|
||||
# CHECK: [ -3. -4.]
|
||||
# ...
|
||||
# CHECK: [-61. -62.]
|
||||
# CHECK: [-63. -64.]{{\]\]}}
|
||||
def test_sparse_eltwise():
|
||||
class EltNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -280,7 +296,9 @@ def test_sparse_eltwise():
|
|||
def forward(self, x):
|
||||
return -x
|
||||
|
||||
dense_input = torch.ones(8, 4, 2)
|
||||
dense_input = torch.reshape(
|
||||
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
|
||||
)
|
||||
|
||||
# This yields a **batched** CSR.
|
||||
sparse_input = dense_input.to_sparse_csr(dense_dim=0)
|
||||
|
@ -291,3 +309,17 @@ def test_sparse_eltwise():
|
|||
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
||||
m = export_and_import(EltNet(), sparse_input)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
#
|
||||
# TODO: note several issues that need to be fixed
|
||||
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
|
||||
# (2) for dense_dim=0, this will need a dense(batched) property
|
||||
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
||||
net = EltNet()
|
||||
res1 = net(sparse_input)
|
||||
res2 = sparse_jit(net, sparse_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
|
Loading…
Reference in New Issue