mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] test for sparse "activation" (#3304)
Example of introducing sparsity into the forward pass, with a bespoke sparsity propagation (upstream PyTorch will eventually support this natively).
parent ec6d7aa5d2
commit c4b28e8d9f
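For context, a minimal sketch (not part of this commit) of the idea behind the propagation in the first hunk: zero-preserving element-wise unary ops keep a sparse activation sparse, so the exporter can simply forward the input's sparsity annotation, while to_sparse() introduces a fresh COO annotation. PyTorch's torch.sparse already behaves this way at runtime:

import torch

# Zero-preserving unary ops applied to a COO activation stay sparse, which is
# why the propagation below copies the input's sparsity annotation unchanged.
x = torch.ones(2, 2, 2)
s = x.to_sparse()           # introduces a sparse COO activation
print(s.layout)             # torch.sparse_coo
print(torch.abs(s).layout)  # abs is zero preserving: still torch.sparse_coo
print(torch.neg(s).layout)  # neg too: still torch.sparse_coo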
@@ -123,6 +123,11 @@ def sparse_export(
         # Zero preserving elt-wise unary op.
         if node.name in {"abs", "neg", "relu", "sin"}:
             node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+        elif node.name == "_to_sparse":
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+            )
     return prog


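A hedged, runnable sketch of what the SparsityMeta arguments above appear to encode. The field names here are an assumption inferred from the call site and from posWidth = 64 / crdWidth = 64 in the encoding checked below; the real dataclass is defined elsewhere in this test file.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch


# Hypothetical mirror of the SparsityMeta fields (names assumed, see above).
@dataclass
class SparsityMetaSketch:
    layout: torch.layout                  # torch.sparse_coo here
    batch_dim: int                        # leading batch dims (0 here)
    sparse_dim: int                       # sparse dims (all dims of the input)
    dense_dim: int                        # trailing dense dims (0 here)
    blocksize: Optional[Tuple[int, int]]  # None: COO is not a blocked format
    pos_dtype: torch.dtype                # int64 -> posWidth = 64
    crd_dtype: torch.dtype                # int64 -> crdWidth = 64


# The annotation attached to a rank-3 _to_sparse result:
meta = SparsityMetaSketch(torch.sparse_coo, 0, 3, 0, None, torch.int64, torch.int64)
print(meta)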
@@ -458,3 +463,51 @@ def test_sparse_coo3():
     print("torch.sparse")
     print(res1)
     print("torch.mlir")
+
+
+@run
+# CHECK-LABEL: test_sparse_activation
+# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
+# CHECK: func.func @main(
+# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> {
+# CHECK: %[[N1:.*]] = torch.constant.none
+# CHECK: %[[N2:.*]] = torch.constant.none
+# CHECK: %[[N3:.*]] = torch.constant.none
+# CHECK: %[[R:.*]] = torch.operator "torch.aten._to_sparse"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]>
+# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]>
+# CHECK: }
+#
+# CHECK: torch.sparse
+# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
+# CHECK:                             [0, 0, 1, 1, 0, 0, 1, 1],
+# CHECK:                             [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
+# CHECK:        values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
+# CHECK:        size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
+# CHECK: torch.mlir
+# CHECK: [0 8]
+# CHECK: [0 0 0 0 1 1 1 1]
+# CHECK: [0 0 1 1 0 0 1 1]
+# CHECK: [0 1 0 1 0 1 0 1]
+# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.]
+#
+def test_sparse_activation():
+    class SparseActivationCOO(torch.nn.Module):
+        def forward(self, x):
+            return x.to_sparse()
+
+    net = SparseActivationCOO()
+    x = torch.ones(2, 2, 2)
+    m = export_and_import(net, x)
+    print(m)
+
+    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
+    res1 = net(x)
+    res2 = sparse_jit(net, x)
+    print("torch.sparse")
+    print(res1)
+    print("torch.mlir")
+    print(res2[0])
+    print(res2[1])
+    print(res2[2])
+    print(res2[3])
+    print(res2[4])
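The five arrays that sparse_jit prints for this test are the positions buffer, the three per-dimension coordinate buffers, and the values buffer of the COO result. A hedged sketch (variable names are illustrative, not part of the test) of reassembling them into a torch COO tensor that matches the torch.sparse output above:

import numpy as np
import torch

# The five buffers printed above, in the order sparse_jit returns them.
pos = np.array([0, 8])                    # positions: one range covering all 8 entries
c0 = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # coordinates, dim 0
c1 = np.array([0, 0, 1, 1, 0, 0, 1, 1])  # coordinates, dim 1
c2 = np.array([0, 1, 0, 1, 0, 1, 0, 1])  # coordinates, dim 2
vals = np.ones(8, dtype=np.float32)      # values

indices = torch.from_numpy(np.stack([c0, c1, c2]))
rebuilt = torch.sparse_coo_tensor(indices, torch.from_numpy(vals), size=(2, 2, 2))
print(rebuilt)  # same indices, values, and nnz as the torch.sparse result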