mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] re-enable all sparse tests (#3444)
this fixes the following issue: https://github.com/llvm/torch-mlir/issues/3418pull/3450/head
parent
7e0e23c668
commit
d77bab37d1
|
@ -2579,6 +2579,8 @@ private:
|
|||
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
|
||||
"torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr",
|
||||
"torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc",
|
||||
"torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr",
|
||||
"torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc",
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ def sparse_export(
|
|||
# Zero preserving elt-wise unary op.
|
||||
if opname in {"abs", "neg", "relu", "sin"}:
|
||||
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
elif opname == "_to_sparse":
|
||||
elif opname == "_to_sparse" or opname == "to_sparse":
|
||||
dim = len(node.meta.get("val").shape)
|
||||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||
|
@ -339,6 +339,14 @@ def test_sparse_SpMV():
|
|||
|
||||
@run
|
||||
#
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
|
||||
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
|
||||
# CHECK: }
|
||||
##
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
|
@ -360,7 +368,7 @@ def test_sparse_SpMM():
|
|||
dense_input = torch.ones(8, 8)
|
||||
sparse_input = dense_input.to_sparse_coo()
|
||||
m = export_and_import(net, sparse_input, dense_input)
|
||||
# print(m)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(sparse_input, dense_input)
|
||||
|
@ -500,12 +508,29 @@ def test_sparse_coo3():
|
|||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_activation
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> {
|
||||
# CHECK: %[[N1:.*]] = torch.constant.none
|
||||
# CHECK: %[[N2:.*]] = torch.constant.none
|
||||
# CHECK: %[[N3:.*]] = torch.constant.none
|
||||
# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
|
||||
# CHECK: [0, 0, 1, 1, 0, 0, 1, 1],
|
||||
# CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
|
||||
# CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
|
||||
# CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [0 8]
|
||||
# CHECK: [0 0 0 0 1 1 1 1]
|
||||
# CHECK: [0 0 1 1 0 0 1 1]
|
||||
# CHECK: [0 1 0 1 0 1 0 1]
|
||||
# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.]
|
||||
#
|
||||
def test_sparse_activation():
|
||||
class SparseActivationCOO(torch.nn.Module):
|
||||
|
@ -515,19 +540,19 @@ def test_sparse_activation():
|
|||
net = SparseActivationCOO()
|
||||
x = torch.ones(2, 2, 2)
|
||||
m = export_and_import(net, x)
|
||||
# print(m)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(x)
|
||||
# res2 = sparse_jit(net, x)
|
||||
res2 = sparse_jit(net, x)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
# print("torch.mlir")
|
||||
# print(res2[0])
|
||||
# print(res2[1])
|
||||
# print(res2[2])
|
||||
# print(res2[3])
|
||||
# print(res2[4])
|
||||
print("torch.mlir")
|
||||
print(res2[0])
|
||||
print(res2[1])
|
||||
print(res2[2])
|
||||
print(res2[3])
|
||||
print(res2[4])
|
||||
|
||||
|
||||
@run
|
||||
|
@ -542,6 +567,8 @@ def test_sparse_activation():
|
|||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.]
|
||||
#
|
||||
def test_sparse_network():
|
||||
def spike(input):
|
||||
|
@ -607,15 +634,24 @@ def test_sparse_network():
|
|||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(x)
|
||||
# res2 = sparse_jit(net, x)
|
||||
res2 = sparse_jit(net, x)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
# print("torch.mlir")
|
||||
# print(res2)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_feature_scaling
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
|
||||
# ... more IR ...
|
||||
# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"
|
||||
# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]]
|
||||
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
|
||||
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
|
||||
|
@ -638,7 +674,7 @@ def test_sparse_feature_scaling():
|
|||
torch.manual_seed(0)
|
||||
f = torch.rand(4, 4)
|
||||
m = export_and_import(net, f)
|
||||
# print(m)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(f)
|
||||
|
|
Loading…
Reference in New Issue