mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] support e2e sparse kernels with COO inputs. (#2939)
parent
ed6e75908b
commit
e85a2a87c5
|
@ -273,7 +273,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
|
|||
|
||||
if sparsity.layout is torch.sparse_coo:
|
||||
assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims
|
||||
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton"
|
||||
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
|
||||
elif sparsity.layout is torch.sparse_csr:
|
||||
assert sparse_dim == 2 and blocksize is None
|
||||
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
|
||||
|
|
|
@ -161,7 +161,13 @@ def sparse_jit(f, *args, **kwargs):
|
|||
for a in args:
|
||||
if a.layout is torch.sparse_coo:
|
||||
xargs.append(a.values().numpy())
|
||||
xargs.append(a.indices().numpy())
|
||||
# Construct the additional position array required by MLIR with data
|
||||
# array([0, nnz]).
|
||||
xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
|
||||
# Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
|
||||
# MLIR SoA COO representation.
|
||||
for idx in a.indices():
|
||||
xargs.append(idx.numpy())
|
||||
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
|
||||
xargs.append(a.values().numpy())
|
||||
xargs.append(a.crow_indices().numpy())
|
||||
|
@ -254,7 +260,7 @@ def test_sparse_SpMV():
|
|||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_SpMM
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
|
||||
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
|
||||
|
@ -286,8 +292,7 @@ def test_sparse_SpMM():
|
|||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
# TODO: run with COO, right now only CSR works
|
||||
sparse_input = dense_input.to_sparse_csr()
|
||||
net = MatMulNet()
|
||||
res1 = net(sparse_input, dense_input)
|
||||
res2 = sparse_jit(net, sparse_input, dense_input)
|
||||
print("torch.sparse")
|
||||
|
|
Loading…
Reference in New Issue