[torch-mlir][sparse] support e2e sparse kernels with COO inputs. (#2939)

Peiming Liu, 2024-02-28 16:08:37 -08:00, committed by GitHub
parent ed6e75908b
commit e85a2a87c5
2 changed files with 10 additions and 5 deletions


@@ -273,7 +273,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     if sparsity.layout is torch.sparse_coo:
         assert sparse_dim == 2 and blocksize is None  # TODO: deeper sparse dims
-        lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton"
+        lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
     elif sparsity.layout is torch.sparse_csr:
         assert sparse_dim == 2 and blocksize is None
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
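For orientation, here is a minimal standalone sketch (not part of the patch, and not the real sparsity_encoding helper) of the level string the patched branch now produces for a plain 2-D COO tensor with no batch or dense dimensions:

import torch

# Illustrative only: mimics the COO branch of sparsity_encoding for the
# simple case batch_dim == 0, showing the new singleton(soa) level.
def coo_lvls(t: torch.Tensor) -> str:
    assert t.layout is torch.sparse_coo and t.sparse_dim() == 2
    batch_dim = t.ndim - t.sparse_dim()  # 0 for a plain 2-D COO tensor
    return f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"

a = torch.eye(3).to_sparse_coo()
print(coo_lvls(a))  # d0:compressed(nonunique),d1:singleton(soa)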


@@ -161,7 +161,13 @@ def sparse_jit(f, *args, **kwargs):
     for a in args:
         if a.layout is torch.sparse_coo:
             xargs.append(a.values().numpy())
-            xargs.append(a.indices().numpy())
+            # Construct the additional position array required by MLIR, with
+            # data array([0, nnz]).
+            xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
+            # Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
+            # to the MLIR SoA COO representation.
+            for idx in a.indices():
+                xargs.append(idx.numpy())
         elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
             xargs.append(a.values().numpy())
             xargs.append(a.crow_indices().numpy())
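To make the new calling convention concrete, the following self-contained sketch (the function name is made up; sparse_jit does this inline) flattens a COO tensor into the buffers the MLIR runtime expects: the values, a positions array [0, nnz], and one coordinate array per sparse dimension, instead of PyTorch's single ndim x nnz index matrix:

import torch

# Illustrative only: repeats the flattening that sparse_jit now performs
# for a COO argument.
def flatten_coo(a: torch.Tensor):
    assert a.layout is torch.sparse_coo
    bufs = [a.values().numpy()]
    # Positions array [0, nnz] required by the MLIR sparse runtime.
    bufs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
    # One coordinate array per sparse dimension (SoA form).
    for idx in a.indices():
        bufs.append(idx.numpy())
    return bufs

a = torch.eye(4).to_sparse_coo()
for buf in flatten_coo(a):
    print(buf.dtype, buf.shape)
# Expected: float32 (4,) values, int64 (2,) positions, and two int64 (4,)
# coordinate arrays (rows and columns).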
@@ -254,7 +260,7 @@ def test_sparse_SpMV():
 @run
 # CHECK-LABEL: test_sparse_SpMM
-# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
+# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
 # CHECK: func.func @main(
 # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
 # CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
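A side note on the CHECK line above: posWidth = 64 and crdWidth = 64 match the int64 indices PyTorch uses for COO tensors, which can be confirmed with a quick check (illustrative, not part of the test):

import torch

# PyTorch stores COO indices as int64, hence the 64-bit position and
# coordinate widths in the expected encoding.
sparse_input = torch.ones(8, 8).to_sparse_coo()
print(sparse_input.indices().dtype)  # torch.int64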
@@ -286,8 +292,7 @@ def test_sparse_SpMM():
     print(m)
     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
-    # TODO: run with COO, right now only CSR works
-    sparse_input = dense_input.to_sparse_csr()
+    sparse_input = dense_input.to_sparse_coo()
     net = MatMulNet()
     res1 = net(sparse_input, dense_input)
     res2 = sparse_jit(net, sparse_input, dense_input)
     print("torch.sparse")
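For reference, here is a self-contained sketch of the end-to-end COO path the updated test exercises; MatMulNetSketch is a stand-in for the test's own MatMulNet, and the sparse_jit call is left commented because it needs the torch-mlir test harness:

import torch

# Stand-in for the test's MatMulNet module.
class MatMulNetSketch(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

dense_input = torch.ones(8, 8)
sparse_input = dense_input.to_sparse_coo()  # COO input, no longer CSR-only
net = MatMulNetSketch()
res1 = net(sparse_input, dense_input)       # eager torch.sparse reference
# res2 = sparse_jit(net, sparse_input, dense_input)  # torch-mlir path from the test
print("torch.sparse")
print(res1)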