mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add JIT test to expose pending issues (#2906)
This test exposes issues that need fixing (1) propagate sparsity into the FX graph (over elt-wise) (2) batched dimensions need a new "dense(batch)" formatpull/2909/head
parent
3e836d8dad
commit
24c2fc0b5f
|
@ -272,6 +272,22 @@ def test_sparse_SpMM():
|
||||||
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
|
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32>
|
||||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
|
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
|
#
|
||||||
|
# CHECK: torch.sparse
|
||||||
|
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
|
||||||
|
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
|
||||||
|
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||||
|
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||||
|
# CHECK: [ -3., -4.],
|
||||||
|
# ...
|
||||||
|
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
|
||||||
|
# CHECK: layout=torch.sparse_csr)
|
||||||
|
# CHECK: torch.mlir
|
||||||
|
# CHECK: {{\[\[}}[ -1. -2.]
|
||||||
|
# CHECK: [ -3. -4.]
|
||||||
|
# ...
|
||||||
|
# CHECK: [-61. -62.]
|
||||||
|
# CHECK: [-63. -64.]{{\]\]}}
|
||||||
def test_sparse_eltwise():
|
def test_sparse_eltwise():
|
||||||
class EltNet(torch.nn.Module):
|
class EltNet(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -280,7 +296,9 @@ def test_sparse_eltwise():
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return -x
|
return -x
|
||||||
|
|
||||||
dense_input = torch.ones(8, 4, 2)
|
dense_input = torch.reshape(
|
||||||
|
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
|
||||||
|
)
|
||||||
|
|
||||||
# This yields a **batched** CSR.
|
# This yields a **batched** CSR.
|
||||||
sparse_input = dense_input.to_sparse_csr(dense_dim=0)
|
sparse_input = dense_input.to_sparse_csr(dense_dim=0)
|
||||||
|
@ -291,3 +309,17 @@ def test_sparse_eltwise():
|
||||||
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
||||||
m = export_and_import(EltNet(), sparse_input)
|
m = export_and_import(EltNet(), sparse_input)
|
||||||
print(m)
|
print(m)
|
||||||
|
|
||||||
|
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||||
|
#
|
||||||
|
# TODO: note several issues that need to be fixed
|
||||||
|
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
|
||||||
|
# (2) for dense_dim=0, this will need a dense(batched) property
|
||||||
|
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
|
||||||
|
net = EltNet()
|
||||||
|
res1 = net(sparse_input)
|
||||||
|
res2 = sparse_jit(net, sparse_input)
|
||||||
|
print("torch.sparse")
|
||||||
|
print(res1)
|
||||||
|
print("torch.mlir")
|
||||||
|
print(res2)
|
||||||
|
|
Loading…
Reference in New Issue