[torch-mlir][sparse] add ID-net example (#3127)

first sparse-in/sparse-out example, will be used
to make actual sparse output work!
pull/3130/head
Aart Bik 2024-04-09 11:21:30 -07:00 committed by GitHub
parent 8ff28527cb
commit 184d8c13f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 5 deletions

View File

@ -162,12 +162,12 @@ def sparse_jit(f, *args, **kwargs):
if a.layout is torch.sparse_coo:
# Construct the additional position array required by MLIR with data
# array([0, nnz]).
xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
xargs.append(torch.tensor([0, a._nnz()], dtype=a._indices().dtype).numpy())
# Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
# MLIR SoA COO representation.
for idx in a.indices():
for idx in a._indices():
xargs.append(idx.numpy())
xargs.append(a.values().numpy())
xargs.append(a._values().numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.crow_indices().numpy())
xargs.append(a.col_indices().numpy())
@ -189,6 +189,46 @@ def run(f):
print()
@run
# CHECK-LABEL: test_sparse_id
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> {
# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK: [ 0, 1, 10, 19]{{\]}}),
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
#
def test_sparse_id():
class IdNet(torch.nn.Module):
def __init__(self):
super(IdNet, self).__init__()
def forward(self, x):
return x
net = IdNet()
idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20])
m = export_and_import(net, sparse_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
# TODO: make output work
res1 = net(sparse_input)
# res2 = sparse_jit(net, sparse_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
# print(res2)
@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
@ -362,8 +402,7 @@ def test_sparse_eltwise():
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
#
# TODO: note several issues that need to be fixed
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
# TODO: propagate sparsity into elt-wise (instead of dense result)
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
res3 = sparse_jit(net, batch_input)