[torch-mlir][sparse] add a COO test for 3-dim (#3119)

This tests COO for more than 2-dim. Note that sparsity should really
propagate into the relu activation and the output, but such cleverness
needs to wait for the pending work in the PyTorch tree.
pull/3124/head
Aart Bik 2024-04-08 16:46:51 -07:00 committed by GitHub
parent dd967eb199
commit 5797d3aa57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 33 additions and 2 deletions

View File

@ -48,8 +48,8 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
sparse_dim, sparse_dim,
dense_dim, dense_dim,
blocksize, blocksize,
a.indices().dtype, a._indices().dtype,
a.indices().dtype, a._indices().dtype,
) )
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
if a.layout is torch.sparse_bsr: if a.layout is torch.sparse_bsr:
@ -373,3 +373,34 @@ def test_sparse_eltwise():
print(res2) print(res2)
print("torch.mlir.batch") print("torch.mlir.batch")
print(res3) print(res3)
@run
# CHECK-LABEL: test_sparse_coo3
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> {
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64>
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64>
# CHECK: }
#
# TODO: make sure sparsity propagates through relu into the output and test actual JIT output
#
def test_sparse_coo3():
class COO3Net(torch.nn.Module):
def __init__(self):
super(COO3Net, self).__init__()
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(x)
net = COO3Net()
# Direct 3-dim COO construction.
idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]])
val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64)
sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30])
m = export_and_import(net, sparse_input)
print(m)