[torch-mlir][sparse] implement first sparse_jit end-to-end path (#2894)

This PR introduces a sparse_jit wrapper that can run simple models with
sparse tensor inputs end-to-end. The implementation shows all required
components on modifying sparse tensor types with a 1:N relation on the
call sites. Two tests shows that the JIT runs end-to-end while computing
the correct results.

More details to follow (generalizing to COO and different ranks, as well
as support for *output* sparse tensors), but the general concepts are
all here now.

**_Update: Thanks to Rob, bump to proper LLVM/MLIR hash is done!_**

_**NOTE that all parameter passing changes are nicely done "downstream"
in MLIR, so very little changes are required in torch-mlir code
proper**_

---------

Co-authored-by: Franz Haniel <77495327+frafranz@users.noreply.github.com>
Co-authored-by: Franz Haniel <franz.haniel@amd.com>
pull/2900/head
Aart Bik 2024-02-12 10:04:54 -08:00 committed by GitHub
parent bfb93cb99f
commit be8375d350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 168 additions and 45 deletions

View File

@ -136,6 +136,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
"convert-shape-to-std",
# MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
# to ensure operations on sparse tensors are lowered to loops.
"sparse-assembler",
"sparsification-and-bufferization",
"sparse-storage-specifier-to-llvm",
# Bufferize.

View File

@ -208,7 +208,9 @@ SYMBOLIC_OP_TO_TORCH_OP = {
}
def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
def sparsity_encoding(
shape: torch.Size, sparsity: tuple[torch.layout, int, int]
) -> str:
"""Returns sparse tensor encoding for the given sparse layout as string.
The method currently just supports 2-dim sparse formats. This should be
@ -216,20 +218,24 @@ def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
is currently implememented in torch.sparse, this should not a be problem.
"""
assert sparsity is not None
sparse_layout, posw, crdw = sparsity
# TODO: any rank
if len(shape) != 2:
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
if sparse_layout is torch.sparse_coo:
return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>"
if sparse_layout is torch.sparse_csr:
return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>"
if sparse_layout is torch.sparse_csc:
return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>"
# TODO: block format (derive block size!)
smap = f"(i,j)->(i:compressed(nonunique),j:singleton)"
elif sparse_layout is torch.sparse_csr:
smap = f"(i,j)->(i:dense,j:compressed)"
elif sparse_layout is torch.sparse_csc:
smap = f"(i,j)->(j:dense,i:compressed)"
else:
# TODO: block format (derive block size!)
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
return f"#sparse_tensor.encoding<{{map={smap},posWidth={posw},crdWidth={crdw}}}>"
def is_symbolic(obj: Any) -> bool:
@ -479,14 +485,19 @@ class ContextCache:
"""Return IrType for !torch.vtensor with the given shape and dtype"""
def get_vtensor_type(
self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None
self,
shape: torch.Size,
dtype: torch.dtype,
*,
sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only
):
shape_asm = self.format_asm_shape(shape)
mlir_dtype = str(self.dtype_to_type(dtype))
if sparse_layout is not None:
sparsity = sparsity_encoding(shape, sparse_layout)
if sparsity is not None:
encoding = sparsity_encoding(shape, sparsity)
assert encoding is not None
return IrType.parse(
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>",
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
context=self._c,
)
return IrType.parse(
@ -497,7 +508,7 @@ class ContextCache:
try:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparse_layout = node.meta.get("sparsity", None)
sparsity = node.meta.get("sparsity", None)
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
# Quantized tensor meta data is not preserved in our lowering,
@ -507,12 +518,14 @@ class ContextCache:
f"Quantized tensor meta data is not supported."
)
else:
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
return self.tensor_metadata_to_type(tensor_meta, sparsity=sparsity)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
if isinstance(val, TorchFakeTensor):
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)
return self.get_vtensor_type(
val.size(), val.dtype, sparsity=sparsity
)
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
if t is not None:
@ -528,16 +541,19 @@ class ContextCache:
)
def tensor_metadata_to_type(
self, tm: TensorMetadata, sparse_layout: torch.layout = None
self,
tm: TensorMetadata,
*,
sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only
) -> IrType:
tm_shape = tuple(
item.node if is_symbolic(item) else item for item in list(tm.shape)
)
key = (tm_shape, tm.dtype, sparse_layout)
key = (tm_shape, tm.dtype, sparsity)
t = self._tensor_metadata_cache.get(key)
if t is None:
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
t = self.get_vtensor_type(tm.shape, tm.dtype, sparsity=sparsity)
self._tensor_metadata_cache[key] = t
return t
@ -1128,7 +1144,8 @@ class TypeSubclassMap:
# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType: ...
class EmptyType:
...
Empty = EmptyType()

View File

@ -5,7 +5,7 @@
# RUN: %PYTHON %s | FileCheck %s
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.export
@ -14,6 +14,10 @@ import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
# All sparse layouts currently supported in torch.sparse.
@ -22,13 +26,50 @@ SPARSE_LAYOUTS = [
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc
torch.sparse_bsc,
]
def sparse_export(f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
def sparse_overhead_width(d: torch.dtype) -> int:
"""Returns bit-width for admissible overhead type."""
if d is torch.int64:
return 64
if d is torch.int32:
return 32
if d is torch.int16:
return 16
if d is torch.int8:
return 8
raise RuntimeError(f"Unsupported overhead type {d}")
def sparse_metadata(a: torch.Tensor) -> tuple[torch.layout, int, int]:
"""Returns a meta data tuple for the given sparse tensor."""
if a.layout is torch.sparse_coo:
return (
a.layout,
sparse_overhead_width(a.indices().dtype),
sparse_overhead_width(a.indices().dtype),
)
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
return (
a.layout,
sparse_overhead_width(a.crow_indices().dtype),
sparse_overhead_width(a.col_indices().dtype),
)
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
return (
a.layout,
sparse_overhead_width(a.ccol_indices().dtype),
sparse_overhead_width(a.row_indices().dtype),
)
else:
raise RuntimeError(f"Unsupported sparse layout for {a}")
def sparse_export(
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
) -> torch.export.ExportedProgram:
"""
This is a ***temporary*** wrapper around `torch.export.export`
that eventually should be removed and simply replaced by the
@ -47,17 +88,16 @@ def sparse_export(f: Callable,
resovled.
"""
# Convert all arguments to dense.
dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args )
mask = [ a.layout in SPARSE_LAYOUTS for a in args ]
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
mask = [a.layout in SPARSE_LAYOUTS for a in args]
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs, constraints=None)
# Annotate sparse arguments in the graph.
alen = len(args)
for i, node in enumerate(prog.graph.nodes):
if node.op == "placeholder" and i < alen and mask[i]:
node.meta['sparsity'] = args[i].layout
# TODO: annotate inputs to change calling conventions!
if node.op == "placeholder" and i < alen and mask[i]:
node.meta["sparsity"] = sparse_metadata(args[i])
return prog
@ -68,7 +108,46 @@ def export_and_import(f, *args, **kwargs):
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
fx_importer.import_frozen_exported_program(prog)
return fx_importer.module_op
return fx_importer.module
def sparse_jit(f, *args, **kwargs):
"""This method compiles and runs the given callable using linalg backend."""
# Import module and lower into Linalg IR.
module = export_and_import(f, *args, *kwargs)
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering TorchFX IR -> Linalg IR",
enable_ir_printing=False,
)
# Compile with reference Linalg backend.
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
invoker = backend.load(compiled)
# Prepare input parameters. Sparse input tensors are split into
# their composite tensors. All PyTorch tensors are converted
# to their backing numpy arrays.
#
# TODO: sparse output tensors
#
xargs = []
for a in args:
if a.layout is torch.sparse_coo:
xargs.append(a.values().numpy())
xargs.append(a.indices().numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.values().numpy())
xargs.append(a.crow_indices().numpy())
xargs.append(a.col_indices().numpy())
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
xargs.append(a.values().numpy())
xargs.append(a.ccol_indices().numpy())
xargs.append(a.row_indices().numpy())
else:
xargs.append(a.numpy())
# Invoke.
return invoker.main(*xargs)
def run(f):
@ -80,51 +159,77 @@ def run(f):
@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK: %[[N:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
# CHECK: }
#
# CHECK: torch.sparse = tensor(4096.)
# CHECK: torch.mlir = 4096.0
#
def test_sparse_sum():
class SumNet(torch.nn.Module):
def __init__(self):
super(SumNet, self).__init__()
def forward(self, x):
return x.sum()
dense_input = torch.ones(64, 64)
dense_input = torch.ones(64, 64)
sparse_input = dense_input.to_sparse_csr()
m = export_and_import(SumNet(), sparse_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
net = SumNet()
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
print("torch.sparse =", res1)
print("torch.mlir =", res2)
@run
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
#
def test_sparse_SpMM():
class MatMulNet(torch.nn.Module):
def __init__(self):
super(MatMulNet, self).__init__()
def forward(self, x, y):
return torch.matmul(x, y)
return torch.matmul(x, y)
dense_input = torch.ones(64, 64)
dense_input = torch.ones(8, 8)
sparse_input = dense_input.to_sparse_coo()
m = export_and_import(MatMulNet(), sparse_input, dense_input)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
# TODO: run with COO, right now only CSR works
sparse_input = dense_input.to_sparse_csr()
net = MatMulNet()
res1 = net(sparse_input, dense_input)
res2 = sparse_jit(net, sparse_input, dense_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)