mirror of https://github.com/llvm/torch-mlir
[sparse] fix double free due to incompatibility between buffer-deallo… (#3303)
…cation and sparse tensors. **NOTE**: This PR _doges_ the issue in buffer-deallocation pass instead of resolving it. In the future, we need to fix the bug in buffer-deallocation pass when handling code generated by sparse compiler.pull/3308/head
parent
5213557b87
commit
cff144b3ac
|
@ -155,7 +155,8 @@ LOWERING_PIPELINE = (
|
|||
"sparse-assembler{direct-out}",
|
||||
"sparsification-and-bufferization",
|
||||
"sparse-storage-specifier-to-llvm",
|
||||
"inline", # inline sparse helper methods where useful
|
||||
# Buffer deallocation pass does not know how to handle realloc.
|
||||
"func.func(expand-realloc)",
|
||||
# Bufferize.
|
||||
"func.func(scf-bufferize)",
|
||||
"func.func(tm-tensor-bufferize)",
|
||||
|
@ -167,6 +168,9 @@ LOWERING_PIPELINE = (
|
|||
"func.func(tensor-bufferize)",
|
||||
"func.func(finalizing-bufferize)",
|
||||
"func.func(buffer-deallocation)",
|
||||
# Buffer-deallocation does not work with the inlined code generated
|
||||
# by sparse tensor dialect.
|
||||
"inline", # inline sparse helper methods where useful
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
# Specifically, we rewrite calling convention boundaries to be in terms
|
||||
# of unranked memref, and we rewrite the return to actually be a
|
||||
|
@ -180,7 +184,6 @@ LOWERING_PIPELINE = (
|
|||
"func.func(tm-tensor-to-loops)",
|
||||
"func.func(refback-munge-memref-copy)",
|
||||
"func.func(convert-linalg-to-loops)",
|
||||
"func.func(expand-realloc)",
|
||||
"func.func(lower-affine)",
|
||||
"convert-scf-to-cf",
|
||||
"func.func(refback-expand-ops-for-llvm)",
|
||||
|
|
|
@ -364,26 +364,30 @@ def test_sparse_SpMM():
|
|||
# CHECK-LABEL: test_sparse_eltwise
|
||||
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
|
||||
# CHECK: }
|
||||
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
|
||||
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
|
||||
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||
# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
|
||||
# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
|
||||
# CHECK: values=tensor({{\[}}[ -1., -2.],
|
||||
# ...
|
||||
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
|
||||
# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
|
||||
# CHECK: layout=torch.sparse_csr)
|
||||
#
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: [0 2 4 6 8]
|
||||
# CHECK: [0 1 0 1 0 1 0 1]
|
||||
# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
|
||||
# CHECK: -15. -16.]
|
||||
# CHECK: torch.mlir.batch
|
||||
#
|
||||
def test_sparse_eltwise():
|
||||
|
@ -396,7 +400,7 @@ def test_sparse_eltwise():
|
|||
|
||||
net = EltNet()
|
||||
dense_input = torch.reshape(
|
||||
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
|
||||
torch.arange(1, 17, dtype=torch.float32), shape=(4, 2, 2)
|
||||
)
|
||||
|
||||
# This yields a plain CSR with dense **sub**tensor
|
||||
|
@ -411,12 +415,15 @@ def test_sparse_eltwise():
|
|||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(sparse_input)
|
||||
res2 = sparse_jit(net, sparse_input)
|
||||
# TODO: make these work
|
||||
# res2 = sparse_jit(net, sparse_input)
|
||||
# res3 = sparse_jit(net, batch_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2[0])
|
||||
print(res2[1])
|
||||
print(res2[2])
|
||||
print("torch.mlir.batch")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue