[torch-mlir][sparse] test for sparse "activation" (#3304)

Example of introducing sparsity into the forward pass. With a bespoke
propagation (but upstream PyTorch will support this).
pull/3305/head
Aart Bik 2024-05-08 19:01:24 -07:00 committed by GitHub
parent ec6d7aa5d2
commit c4b28e8d9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 53 additions and 0 deletions

View File

@ -123,6 +123,11 @@ def sparse_export(
# Zero preserving elt-wise unary op.
if node.name in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif node.name == "_to_sparse":
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
return prog
@ -458,3 +463,51 @@ def test_sparse_coo3():
print("torch.sparse")
print(res1)
print("torch.mlir")
@run
# CHECK-LABEL: test_sparse_activation
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> {
# CHECK: %[[N1:.*]] = torch.constant.none
# CHECK: %[[N2:.*]] = torch.constant.none
# CHECK: %[[N3:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.operator "torch.aten._to_sparse"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
# CHECK: [0, 0, 1, 1, 0, 0, 1, 1],
# CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
# CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
# CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: [0 8]
# CHECK: [0 0 0 0 1 1 1 1]
# CHECK: [0 0 1 1 0 0 1 1]
# CHECK: [0 1 0 1 0 1 0 1]
# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.]
#
def test_sparse_activation():
class SparseActivationCOO(torch.nn.Module):
def forward(self, x):
return x.to_sparse()
net = SparseActivationCOO()
x = torch.ones(2, 2, 2)
m = export_and_import(net, x)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(x)
res2 = sparse_jit(net, x)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2[0])
print(res2[1])
print(res2[2])
print(res2[3])
print(res2[4])