From 5797d3aa572561a51a212a74cac9827be1344964 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 8 Apr 2024 16:46:51 -0700 Subject: [PATCH] [torch-mlir][sparse] add a COO test for 3-dim (#3119) This tests COO for more than 2-dim. Note that sparsity should really propagate into the relu activation and the output, but such cleverness needs to wait for the pending work in the PyTorch tree. --- test/python/fx_importer/sparse_test.py | 35 ++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 93144daf9..bfacde280 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -48,8 +48,8 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta: sparse_dim, dense_dim, blocksize, - a.indices().dtype, - a.indices().dtype, + a._indices().dtype, + a._indices().dtype, ) elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: if a.layout is torch.sparse_bsr: @@ -373,3 +373,34 @@ def test_sparse_eltwise(): print(res2) print("torch.mlir.batch") print(res3) + + +@run +# CHECK-LABEL: test_sparse_coo3 +# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> { +# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64> +# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64> +# CHECK: } +# +# TODO: make sure sparsity propagates through relu into the output and test actual JIT output +# +def test_sparse_coo3(): + class COO3Net(torch.nn.Module): + def __init__(self): + super(COO3Net, self).__init__() + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(x) + + net = COO3Net() + + # Direct 3-dim COO construction. + idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]]) + val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64) + sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30]) + + m = export_and_import(net, sparse_input) + print(m)