[sparse] fix double free due to incompatibility between buffer-deallo… (#3303)

…cation and sparse tensors.

**NOTE**: This PR _doges_ the issue in buffer-deallocation pass instead
of resolving it. In the future, we need to fix the bug in
buffer-deallocation pass when handling code generated by sparse
compiler.
pull/3308/head
Peiming Liu 2024-05-08 21:18:17 -07:00 committed by GitHub
parent 5213557b87
commit cff144b3ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 14 deletions

View File

@ -155,7 +155,8 @@ LOWERING_PIPELINE = (
"sparse-assembler{direct-out}", "sparse-assembler{direct-out}",
"sparsification-and-bufferization", "sparsification-and-bufferization",
"sparse-storage-specifier-to-llvm", "sparse-storage-specifier-to-llvm",
"inline", # inline sparse helper methods where useful # Buffer deallocation pass does not know how to handle realloc.
"func.func(expand-realloc)",
# Bufferize. # Bufferize.
"func.func(scf-bufferize)", "func.func(scf-bufferize)",
"func.func(tm-tensor-bufferize)", "func.func(tm-tensor-bufferize)",
@ -167,6 +168,9 @@ LOWERING_PIPELINE = (
"func.func(tensor-bufferize)", "func.func(tensor-bufferize)",
"func.func(finalizing-bufferize)", "func.func(finalizing-bufferize)",
"func.func(buffer-deallocation)", "func.func(buffer-deallocation)",
# Buffer-deallocation does not work with the inlined code generated
# by sparse tensor dialect.
"inline", # inline sparse helper methods where useful
# Munge to make it ExecutionEngine compatible. # Munge to make it ExecutionEngine compatible.
# Specifically, we rewrite calling convention boundaries to be in terms # Specifically, we rewrite calling convention boundaries to be in terms
# of unranked memref, and we rewrite the return to actually be a # of unranked memref, and we rewrite the return to actually be a
@ -180,7 +184,6 @@ LOWERING_PIPELINE = (
"func.func(tm-tensor-to-loops)", "func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)", "func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)", "func.func(convert-linalg-to-loops)",
"func.func(expand-realloc)",
"func.func(lower-affine)", "func.func(lower-affine)",
"convert-scf-to-cf", "convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)", "func.func(refback-expand-ops-for-llvm)",

View File

@ -364,26 +364,30 @@ def test_sparse_SpMM():
# CHECK-LABEL: test_sparse_eltwise # CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> # CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main( # CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> { # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> # CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> # CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: } # CHECK: }
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main( # CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> { # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> # CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> # CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: } # CHECK: }
# #
# CHECK: torch.sparse # CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), # CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, # CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
# CHECK: values=tensor({{\[}}[ -1., -2.], # CHECK: values=tensor({{\[}}[ -1., -2.],
# ... # ...
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, # CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
# CHECK: layout=torch.sparse_csr) # CHECK: layout=torch.sparse_csr)
#
# CHECK: torch.mlir # CHECK: torch.mlir
# CHECK: [0 2 4 6 8]
# CHECK: [0 1 0 1 0 1 0 1]
# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
# CHECK: -15. -16.]
# CHECK: torch.mlir.batch # CHECK: torch.mlir.batch
# #
def test_sparse_eltwise(): def test_sparse_eltwise():
@ -396,7 +400,7 @@ def test_sparse_eltwise():
net = EltNet() net = EltNet()
dense_input = torch.reshape( dense_input = torch.reshape(
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) torch.arange(1, 17, dtype=torch.float32), shape=(4, 2, 2)
) )
# This yields a plain CSR with dense **sub**tensor # This yields a plain CSR with dense **sub**tensor
@ -411,12 +415,15 @@ def test_sparse_eltwise():
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input) res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
# TODO: make these work # TODO: make these work
# res2 = sparse_jit(net, sparse_input)
# res3 = sparse_jit(net, batch_input) # res3 = sparse_jit(net, batch_input)
print("torch.sparse") print("torch.sparse")
print(res1) print(res1)
print("torch.mlir") print("torch.mlir")
print(res2[0])
print(res2[1])
print(res2[2])
print("torch.mlir.batch") print("torch.mlir.batch")