mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add a COO test for 3-dim (#3119)
This tests COO for more than 2-dim. Note that sparsity should really propagate into the relu activation and the output, but such cleverness needs to wait for the pending work in the PyTorch tree.pull/3124/head
parent
dd967eb199
commit
5797d3aa57
|
@ -48,8 +48,8 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
|
|||
sparse_dim,
|
||||
dense_dim,
|
||||
blocksize,
|
||||
a.indices().dtype,
|
||||
a.indices().dtype,
|
||||
a._indices().dtype,
|
||||
a._indices().dtype,
|
||||
)
|
||||
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
|
||||
if a.layout is torch.sparse_bsr:
|
||||
|
@ -373,3 +373,34 @@ def test_sparse_eltwise():
|
|||
print(res2)
|
||||
print("torch.mlir.batch")
|
||||
print(res3)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_coo3
|
||||
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64>
|
||||
# CHECK: }
|
||||
#
|
||||
# TODO: make sure sparsity propagates through relu into the output and test actual JIT output
|
||||
#
|
||||
def test_sparse_coo3():
|
||||
class COO3Net(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(COO3Net, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(x)
|
||||
|
||||
net = COO3Net()
|
||||
|
||||
# Direct 3-dim COO construction.
|
||||
idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]])
|
||||
val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64)
|
||||
sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30])
|
||||
|
||||
m = export_and_import(net, sparse_input)
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue