[torch-mlir][sparse] add JIT test to expose pending issues (#2906)

This test exposes issues that need fixing
(1) propagate sparsity into the FX graph (over elt-wise) (2) batched
dimensions need a new "dense(batch)" format
pull/2909/head
Aart Bik 2024-02-13 13:42:56 -08:00 committed by GitHub
parent 3e836d8dad
commit 24c2fc0b5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 33 additions and 1 deletions

View File

@ -272,6 +272,22 @@ def test_sparse_SpMM():
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# CHECK: [ -3., -4.],
# ...
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
# CHECK: layout=torch.sparse_csr)
# CHECK: torch.mlir
# CHECK: {{\[\[}}[ -1. -2.]
# CHECK: [ -3. -4.]
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
def test_sparse_eltwise():
class EltNet(torch.nn.Module):
def __init__(self):
@ -280,7 +296,9 @@ def test_sparse_eltwise():
def forward(self, x):
return -x
dense_input = torch.ones(8, 4, 2)
dense_input = torch.reshape(
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
)
# This yields a **batched** CSR.
sparse_input = dense_input.to_sparse_csr(dense_dim=0)
@ -291,3 +309,17 @@ def test_sparse_eltwise():
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
m = export_and_import(EltNet(), sparse_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
#
# TODO: note several issues that need to be fixed
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
# (2) for dense_dim=0, this will need a dense(batched) property
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
net = EltNet()
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)