From 24c2fc0b5f90d870cbfd967c81460bcc5686d24d Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 13 Feb 2024 13:42:56 -0800 Subject: [PATCH] [torch-mlir][sparse] add JIT test to expose pending issues (#2906) This test exposes issues that need fixing (1) propagate sparsity into the FX graph (over elt-wise) (2) batched dimensions need a new "dense(batch)" format --- test/python/fx_importer/sparse_test.py | 34 +++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 161b29148..d0b94ac83 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -272,6 +272,22 @@ def test_sparse_SpMM(): # CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32> # CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> # CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# CHECK: [ -3., -4.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -280,7 +296,9 @@ def test_sparse_eltwise(): def forward(self, x): return -x - dense_input = torch.ones(8, 4, 2) + dense_input = torch.reshape( + torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) + ) # This yields a **batched** CSR. sparse_input = dense_input.to_sparse_csr(dense_dim=0) @@ -291,3 +309,17 @@ def test_sparse_eltwise(): sparse_input = dense_input.to_sparse_csr(dense_dim=1) m = export_and_import(EltNet(), sparse_input) print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # + # TODO: note several issues that need to be fixed + # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result + # (2) for dense_dim=0, this will need a dense(batched) property + sparse_input = dense_input.to_sparse_csr(dense_dim=1) + net = EltNet() + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2)