mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] add ID-net example (#3127)
first sparse-in/sparse-out example, will be used to make actual sparse output work!

pull/3130/head
parent 8ff28527cb
commit 184d8c13f4
@@ -162,12 +162,12 @@ def sparse_jit(f, *args, **kwargs):
         if a.layout is torch.sparse_coo:
             # Construct the additional position array required by MLIR with data
             # array([0, nnz]).
-            xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
+            xargs.append(torch.tensor([0, a._nnz()], dtype=a._indices().dtype).numpy())
             # Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
             # to the MLIR SoA COO representation.
-            for idx in a.indices():
+            for idx in a._indices():
                 xargs.append(idx.numpy())
-            xargs.append(a.values().numpy())
+            xargs.append(a._values().numpy())
         elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
             xargs.append(a.crow_indices().numpy())
             xargs.append(a.col_indices().numpy())
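Context for the hunk above: on an uncoalesced COO tensor, torch.Tensor.indices() and .values() raise a RuntimeError until .coalesce() is called, while the underscore accessors _indices() and _values() return the raw buffers either way, which is why the patch switches to them. A minimal standalone sketch of the packing (the small example tensor is assumed for illustration, not part of the patch):

import torch

# An uncoalesced COO tensor; .indices()/.values() would raise here,
# but the underscore accessors still work.
a = torch.sparse_coo_tensor([[0, 1], [0, 1]], [1.0, 2.0], size=(2, 2))

# Pack into the MLIR SoA COO form, mirroring the patched sparse_jit:
# a positions array([0, nnz]), one coordinate array per dimension,
# and finally the values.
xargs = [torch.tensor([0, a._nnz()], dtype=a._indices().dtype).numpy()]
for idx in a._indices():  # tensor<ndim x nnz> -> ndim arrays of length nnz
    xargs.append(idx.numpy())
xargs.append(a._values().numpy())
print([x.tolist() for x in xargs])  # [[0, 2], [0, 1], [0, 1], [1.0, 2.0]]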
@@ -189,6 +189,46 @@ def run(f):
     print()
 
 
+@run
+# CHECK-LABEL: test_sparse_id
+# CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
+# CHECK:       func.func @main(
+# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> {
+# CHECK:         return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
+# CHECK:       }
+#
+# CHECK: torch.sparse
+# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
+# CHECK:                             [ 0, 1, 10, 19]{{\]}}),
+# CHECK:        values=tensor([-1000., -1., 1., 1000.]),
+# CHECK:        size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
+# CHECK: torch.mlir
+#
+def test_sparse_id():
+    class IdNet(torch.nn.Module):
+        def __init__(self):
+            super(IdNet, self).__init__()
+
+        def forward(self, x):
+            return x
+
+    net = IdNet()
+    idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
+    val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
+    sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20])
+    m = export_and_import(net, sparse_input)
+    print(m)
+
+    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
+    # TODO: make output work
+    res1 = net(sparse_input)
+    # res2 = sparse_jit(net, sparse_input)
+    print("torch.sparse")
+    print(res1)
+    print("torch.mlir")
+    # print(res2)
+
+
 @run
 # CHECK-LABEL: test_sparse_sum
 # CHECK:       #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
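The #[[$COO]] encoding checked above is the MLIR SoA COO format: level d0 is compressed(nonunique), backed by the positions array, and level d1 is singleton(soa), backed by one coordinate array per dimension. A hedged sketch, reusing the test's own input, of the buffers the patched sparse_jit packing would produce for it (variable names here are illustrative only):

import torch

# The same 10x20 COO input that test_sparse_id builds.
idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20])

# Buffers in MLIR SoA COO order (see the sparse_jit hunk above).
pos = torch.tensor([0, sparse_input._nnz()]).numpy()        # positions: [0, 4]
crd0, crd1 = (i.numpy() for i in sparse_input._indices())   # one coord array per dim
vals = sparse_input._values().numpy()                       # [-1000., -1., 1., 1000.]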
@@ -362,8 +402,7 @@ def test_sparse_eltwise():
 
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     #
-    # TODO: note several issues that need to be fixed
-    # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
+    # TODO: propagate sparsity into elt-wise (instead of dense result)
     res1 = net(sparse_input)
     res2 = sparse_jit(net, sparse_input)
     res3 = sparse_jit(net, batch_input)
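The rewritten TODO refers to the gap this hunk documents: torch.sparse keeps element-wise results sparse, while the MLIR path currently returns a dense result. A one-line illustration of the PyTorch-side behavior the lowering should eventually match (example tensor assumed):

import torch

s = torch.sparse_coo_tensor([[0, 1], [0, 1]], [2.0, 3.0], size=(2, 2))
print((s * 2.0).layout)  # torch.sparse_coo: element-wise ops stay sparse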