[torch-mlir][sparse] re-enable all sparse tests (#3444)

this fixes the following issue:

https://github.com/llvm/torch-mlir/issues/3418
pull/3450/head
Aart Bik 2024-06-10 11:19:32 -07:00 committed by GitHub
parent 7e0e23c668
commit d77bab37d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 14 deletions

View File

@ -2579,6 +2579,8 @@ private:
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = { SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
"torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr",
"torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc",
"torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr",
"torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc",
}; };
} // namespace } // namespace

View File

@ -125,7 +125,7 @@ def sparse_export(
# Zero preserving elt-wise unary op. # Zero preserving elt-wise unary op.
if opname in {"abs", "neg", "relu", "sin"}: if opname in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "_to_sparse": elif opname == "_to_sparse" or opname == "to_sparse":
dim = len(node.meta.get("val").shape) dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta( node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
@ -339,6 +339,14 @@ def test_sparse_SpMV():
@run @run
# #
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
# CHECK: }
##
# CHECK: torch.sparse # CHECK: torch.sparse
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], # CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], # CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
@ -360,7 +368,7 @@ def test_sparse_SpMM():
dense_input = torch.ones(8, 8) dense_input = torch.ones(8, 8)
sparse_input = dense_input.to_sparse_coo() sparse_input = dense_input.to_sparse_coo()
m = export_and_import(net, sparse_input, dense_input) m = export_and_import(net, sparse_input, dense_input)
# print(m) print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input, dense_input) res1 = net(sparse_input, dense_input)
@ -500,12 +508,29 @@ def test_sparse_coo3():
@run @run
# #
# CHECK-LABEL: test_sparse_activation
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> {
# CHECK: %[[N1:.*]] = torch.constant.none
# CHECK: %[[N2:.*]] = torch.constant.none
# CHECK: %[[N3:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse # CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1], # CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
# CHECK: [0, 0, 1, 1, 0, 0, 1, 1], # CHECK: [0, 0, 1, 1, 0, 0, 1, 1],
# CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}), # CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
# CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]), # CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
# CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo) # CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: [0 8]
# CHECK: [0 0 0 0 1 1 1 1]
# CHECK: [0 0 1 1 0 0 1 1]
# CHECK: [0 1 0 1 0 1 0 1]
# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.]
# #
def test_sparse_activation(): def test_sparse_activation():
class SparseActivationCOO(torch.nn.Module): class SparseActivationCOO(torch.nn.Module):
@ -515,19 +540,19 @@ def test_sparse_activation():
net = SparseActivationCOO() net = SparseActivationCOO()
x = torch.ones(2, 2, 2) x = torch.ones(2, 2, 2)
m = export_and_import(net, x) m = export_and_import(net, x)
# print(m) print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(x) res1 = net(x)
# res2 = sparse_jit(net, x) res2 = sparse_jit(net, x)
print("torch.sparse") print("torch.sparse")
print(res1) print(res1)
# print("torch.mlir") print("torch.mlir")
# print(res2[0]) print(res2[0])
# print(res2[1]) print(res2[1])
# print(res2[2]) print(res2[2])
# print(res2[3]) print(res2[3])
# print(res2[4]) print(res2[4])
@run @run
@ -542,6 +567,8 @@ def test_sparse_activation():
# #
# CHECK: torch.sparse # CHECK: torch.sparse
# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) # CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
# CHECK: torch.mlir
# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.]
# #
def test_sparse_network(): def test_sparse_network():
def spike(input): def spike(input):
@ -607,15 +634,24 @@ def test_sparse_network():
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(x) res1 = net(x)
# res2 = sparse_jit(net, x) res2 = sparse_jit(net, x)
print("torch.sparse") print("torch.sparse")
print(res1) print(res1)
# print("torch.mlir") print("torch.mlir")
# print(res2) print(res2)
@run @run
# #
# CHECK-LABEL: test_sparse_feature_scaling
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
# ... more IR ...
# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"
# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]]
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK: }
#
# CHECK: torch.sparse # CHECK: torch.sparse
# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], # CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851], # CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
@ -638,7 +674,7 @@ def test_sparse_feature_scaling():
torch.manual_seed(0) torch.manual_seed(0)
f = torch.rand(4, 4) f = torch.rand(4, 4)
m = export_and_import(net, f) m = export_and_import(net, f)
# print(m) print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(f) res1 = net(f)