mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] implement first sparse_jit end-to-end path (#2894)
This PR introduces a sparse_jit wrapper that can run simple models with sparse tensor inputs end-to-end. The implementation shows all required components on modifying sparse tensor types with a 1:N relation on the call sites. Two tests shows that the JIT runs end-to-end while computing the correct results. More details to follow (generalizing to COO and different ranks, as well as support for *output* sparse tensors), but the general concepts are all here now. **_Update: Thanks to Rob, bump to proper LLVM/MLIR hash is done!_** _**NOTE that all parameter passing changes are nicely done "downstream" in MLIR, so very little changes are required in torch-mlir code proper**_ --------- Co-authored-by: Franz Haniel <77495327+frafranz@users.noreply.github.com> Co-authored-by: Franz Haniel <franz.haniel@amd.com>pull/2900/head
parent
bfb93cb99f
commit
be8375d350
|
@ -136,6 +136,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
|||
"convert-shape-to-std",
|
||||
# MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
|
||||
# to ensure operations on sparse tensors are lowered to loops.
|
||||
"sparse-assembler",
|
||||
"sparsification-and-bufferization",
|
||||
"sparse-storage-specifier-to-llvm",
|
||||
# Bufferize.
|
||||
|
|
|
@ -208,7 +208,9 @@ SYMBOLIC_OP_TO_TORCH_OP = {
|
|||
}
|
||||
|
||||
|
||||
def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
|
||||
def sparsity_encoding(
|
||||
shape: torch.Size, sparsity: tuple[torch.layout, int, int]
|
||||
) -> str:
|
||||
"""Returns sparse tensor encoding for the given sparse layout as string.
|
||||
|
||||
The method currently just supports 2-dim sparse formats. This should be
|
||||
|
@ -216,20 +218,24 @@ def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
|
|||
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
|
||||
is currently implememented in torch.sparse, this should not a be problem.
|
||||
"""
|
||||
assert sparsity is not None
|
||||
sparse_layout, posw, crdw = sparsity
|
||||
|
||||
# TODO: any rank
|
||||
if len(shape) != 2:
|
||||
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
|
||||
|
||||
if sparse_layout is torch.sparse_coo:
|
||||
return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>"
|
||||
if sparse_layout is torch.sparse_csr:
|
||||
return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>"
|
||||
if sparse_layout is torch.sparse_csc:
|
||||
return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>"
|
||||
# TODO: block format (derive block size!)
|
||||
smap = f"(i,j)->(i:compressed(nonunique),j:singleton)"
|
||||
elif sparse_layout is torch.sparse_csr:
|
||||
smap = f"(i,j)->(i:dense,j:compressed)"
|
||||
elif sparse_layout is torch.sparse_csc:
|
||||
smap = f"(i,j)->(j:dense,i:compressed)"
|
||||
else:
|
||||
# TODO: block format (derive block size!)
|
||||
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
|
||||
|
||||
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
|
||||
return f"#sparse_tensor.encoding<{{map={smap},posWidth={posw},crdWidth={crdw}}}>"
|
||||
|
||||
|
||||
def is_symbolic(obj: Any) -> bool:
|
||||
|
@ -479,14 +485,19 @@ class ContextCache:
|
|||
"""Return IrType for !torch.vtensor with the given shape and dtype"""
|
||||
|
||||
def get_vtensor_type(
|
||||
self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None
|
||||
self,
|
||||
shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only
|
||||
):
|
||||
shape_asm = self.format_asm_shape(shape)
|
||||
mlir_dtype = str(self.dtype_to_type(dtype))
|
||||
if sparse_layout is not None:
|
||||
sparsity = sparsity_encoding(shape, sparse_layout)
|
||||
if sparsity is not None:
|
||||
encoding = sparsity_encoding(shape, sparsity)
|
||||
assert encoding is not None
|
||||
return IrType.parse(
|
||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>",
|
||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
|
||||
context=self._c,
|
||||
)
|
||||
return IrType.parse(
|
||||
|
@ -497,7 +508,7 @@ class ContextCache:
|
|||
try:
|
||||
tensor_meta = node.meta.get("tensor_meta")
|
||||
val = node.meta.get("val")
|
||||
sparse_layout = node.meta.get("sparsity", None)
|
||||
sparsity = node.meta.get("sparsity", None)
|
||||
if tensor_meta is not None:
|
||||
assert isinstance(tensor_meta, TensorMetadata)
|
||||
# Quantized tensor meta data is not preserved in our lowering,
|
||||
|
@ -507,12 +518,14 @@ class ContextCache:
|
|||
f"Quantized tensor meta data is not supported."
|
||||
)
|
||||
else:
|
||||
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
|
||||
return self.tensor_metadata_to_type(tensor_meta, sparsity=sparsity)
|
||||
elif val is not None:
|
||||
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
||||
# tensor_meta
|
||||
if isinstance(val, TorchFakeTensor):
|
||||
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)
|
||||
return self.get_vtensor_type(
|
||||
val.size(), val.dtype, sparsity=sparsity
|
||||
)
|
||||
|
||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||
if t is not None:
|
||||
|
@ -528,16 +541,19 @@ class ContextCache:
|
|||
)
|
||||
|
||||
def tensor_metadata_to_type(
|
||||
self, tm: TensorMetadata, sparse_layout: torch.layout = None
|
||||
self,
|
||||
tm: TensorMetadata,
|
||||
*,
|
||||
sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only
|
||||
) -> IrType:
|
||||
tm_shape = tuple(
|
||||
item.node if is_symbolic(item) else item for item in list(tm.shape)
|
||||
)
|
||||
|
||||
key = (tm_shape, tm.dtype, sparse_layout)
|
||||
key = (tm_shape, tm.dtype, sparsity)
|
||||
t = self._tensor_metadata_cache.get(key)
|
||||
if t is None:
|
||||
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
|
||||
t = self.get_vtensor_type(tm.shape, tm.dtype, sparsity=sparsity)
|
||||
self._tensor_metadata_cache[key] = t
|
||||
return t
|
||||
|
||||
|
@ -1128,7 +1144,8 @@ class TypeSubclassMap:
|
|||
|
||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||
# may have a different meaning.
|
||||
class EmptyType: ...
|
||||
class EmptyType:
|
||||
...
|
||||
|
||||
|
||||
Empty = EmptyType()
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
|
@ -14,6 +14,10 @@ import torch.nn as nn
|
|||
from torch_mlir.extras.fx_importer import FxImporter
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||
RefBackendLinalgOnTensorsBackend,
|
||||
)
|
||||
|
||||
|
||||
# All sparse layouts currently supported in torch.sparse.
|
||||
|
@ -22,13 +26,50 @@ SPARSE_LAYOUTS = [
|
|||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc
|
||||
torch.sparse_bsc,
|
||||
]
|
||||
|
||||
|
||||
def sparse_export(f: Callable,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
|
||||
def sparse_overhead_width(d: torch.dtype) -> int:
|
||||
"""Returns bit-width for admissible overhead type."""
|
||||
if d is torch.int64:
|
||||
return 64
|
||||
if d is torch.int32:
|
||||
return 32
|
||||
if d is torch.int16:
|
||||
return 16
|
||||
if d is torch.int8:
|
||||
return 8
|
||||
raise RuntimeError(f"Unsupported overhead type {d}")
|
||||
|
||||
|
||||
def sparse_metadata(a: torch.Tensor) -> tuple[torch.layout, int, int]:
|
||||
"""Returns a meta data tuple for the given sparse tensor."""
|
||||
if a.layout is torch.sparse_coo:
|
||||
return (
|
||||
a.layout,
|
||||
sparse_overhead_width(a.indices().dtype),
|
||||
sparse_overhead_width(a.indices().dtype),
|
||||
)
|
||||
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
|
||||
return (
|
||||
a.layout,
|
||||
sparse_overhead_width(a.crow_indices().dtype),
|
||||
sparse_overhead_width(a.col_indices().dtype),
|
||||
)
|
||||
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
|
||||
return (
|
||||
a.layout,
|
||||
sparse_overhead_width(a.ccol_indices().dtype),
|
||||
sparse_overhead_width(a.row_indices().dtype),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported sparse layout for {a}")
|
||||
|
||||
|
||||
def sparse_export(
|
||||
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
|
||||
) -> torch.export.ExportedProgram:
|
||||
"""
|
||||
This is a ***temporary*** wrapper around `torch.export.export`
|
||||
that eventually should be removed and simply replaced by the
|
||||
|
@ -47,17 +88,16 @@ def sparse_export(f: Callable,
|
|||
resovled.
|
||||
"""
|
||||
# Convert all arguments to dense.
|
||||
dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args )
|
||||
mask = [ a.layout in SPARSE_LAYOUTS for a in args ]
|
||||
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
|
||||
mask = [a.layout in SPARSE_LAYOUTS for a in args]
|
||||
# Build the regular FX traced graph with only dense arguments
|
||||
# (the current version would crash otherwise, see issue above).
|
||||
prog = torch.export.export(f, dargs, kwargs, constraints=None)
|
||||
# Annotate sparse arguments in the graph.
|
||||
alen = len(args)
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
if node.op == "placeholder" and i < alen and mask[i]:
|
||||
node.meta['sparsity'] = args[i].layout
|
||||
# TODO: annotate inputs to change calling conventions!
|
||||
if node.op == "placeholder" and i < alen and mask[i]:
|
||||
node.meta["sparsity"] = sparse_metadata(args[i])
|
||||
return prog
|
||||
|
||||
|
||||
|
@ -68,7 +108,46 @@ def export_and_import(f, *args, **kwargs):
|
|||
fx_importer = FxImporter(context=context)
|
||||
prog = sparse_export(f, args, kwargs)
|
||||
fx_importer.import_frozen_exported_program(prog)
|
||||
return fx_importer.module_op
|
||||
return fx_importer.module
|
||||
|
||||
|
||||
def sparse_jit(f, *args, **kwargs):
|
||||
"""This method compiles and runs the given callable using linalg backend."""
|
||||
# Import module and lower into Linalg IR.
|
||||
module = export_and_import(f, *args, *kwargs)
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"Lowering TorchFX IR -> Linalg IR",
|
||||
enable_ir_printing=False,
|
||||
)
|
||||
# Compile with reference Linalg backend.
|
||||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(module)
|
||||
invoker = backend.load(compiled)
|
||||
# Prepare input parameters. Sparse input tensors are split into
|
||||
# their composite tensors. All PyTorch tensors are converted
|
||||
# to their backing numpy arrays.
|
||||
#
|
||||
# TODO: sparse output tensors
|
||||
#
|
||||
xargs = []
|
||||
for a in args:
|
||||
if a.layout is torch.sparse_coo:
|
||||
xargs.append(a.values().numpy())
|
||||
xargs.append(a.indices().numpy())
|
||||
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
|
||||
xargs.append(a.values().numpy())
|
||||
xargs.append(a.crow_indices().numpy())
|
||||
xargs.append(a.col_indices().numpy())
|
||||
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
|
||||
xargs.append(a.values().numpy())
|
||||
xargs.append(a.ccol_indices().numpy())
|
||||
xargs.append(a.row_indices().numpy())
|
||||
else:
|
||||
xargs.append(a.numpy())
|
||||
# Invoke.
|
||||
return invoker.main(*xargs)
|
||||
|
||||
|
||||
def run(f):
|
||||
|
@ -80,51 +159,77 @@ def run(f):
|
|||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_sum
|
||||
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
|
||||
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
|
||||
# CHECK: %[[N:.*]] = torch.constant.none
|
||||
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse = tensor(4096.)
|
||||
# CHECK: torch.mlir = 4096.0
|
||||
#
|
||||
def test_sparse_sum():
|
||||
|
||||
class SumNet(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SumNet, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.sum()
|
||||
|
||||
|
||||
dense_input = torch.ones(64, 64)
|
||||
dense_input = torch.ones(64, 64)
|
||||
sparse_input = dense_input.to_sparse_csr()
|
||||
m = export_and_import(SumNet(), sparse_input)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
net = SumNet()
|
||||
res1 = net(sparse_input)
|
||||
res2 = sparse_jit(net, sparse_input)
|
||||
print("torch.sparse =", res1)
|
||||
print("torch.mlir =", res2)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_SpMM
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
|
||||
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
|
||||
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
|
||||
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
|
||||
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
|
||||
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
|
||||
#
|
||||
def test_sparse_SpMM():
|
||||
|
||||
class MatMulNet(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MatMulNet, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.matmul(x, y)
|
||||
return torch.matmul(x, y)
|
||||
|
||||
|
||||
dense_input = torch.ones(64, 64)
|
||||
dense_input = torch.ones(8, 8)
|
||||
sparse_input = dense_input.to_sparse_coo()
|
||||
m = export_and_import(MatMulNet(), sparse_input, dense_input)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
# TODO: run with COO, right now only CSR works
|
||||
sparse_input = dense_input.to_sparse_csr()
|
||||
net = MatMulNet()
|
||||
res1 = net(sparse_input, dense_input)
|
||||
res2 = sparse_jit(net, sparse_input, dense_input)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
|
Loading…
Reference in New Issue