[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648)

Now that the PyTorch feature request pytorch/pytorch#117188 has been
completed, we can remove all the ad-hoc code that propagates sparsity
metadata and replace it with the built-in PyTorch metadata for sparse
tensors. This removes a lot of code and also ensures sparsity is
consistent with the torch.sparse package for all cases.
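
For illustration, a minimal sketch of the new flow (assuming torch >= 2.5.0, where torch.export preserves sparse tensor metadata; the module class and input tensor here are made up for the example):

    import torch
    from torch_mlir import ir
    from torch_mlir.dialects import torch as torch_d
    from torch_mlir.extras.fx_importer import FxImporter

    class SumNet(torch.nn.Module):
        def forward(self, x):
            return x.sum()

    # Sparse inputs now go straight through the standard export path,
    # with no dense detour or ad-hoc metadata annotation.
    sparse_input = torch.eye(4).to_sparse_csr()
    prog = torch.export.export(SumNet(), (sparse_input,))

    context = ir.Context()
    torch_d.register_dialect(context)
    importer = FxImporter(context=context)
    importer.import_frozen_program(prog)
    print(importer.module)  # vtensor types now carry #sparse_tensor.encoding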
Aart Bik 2024-08-20 09:56:21 -07:00 committed by GitHub
parent 0a86deb59a
commit f72770a725
3 changed files with 58 additions and 190 deletions


@@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
     )
 
 
-@dataclass(frozen=True)
-class SparsityMeta:
-    """
-    Class for keeping track of sparsity meta data.
-    NOTE: this will be fully replaced by
-    torch.fx.passes.shape_prop.SparseTensorMetadata
-    """
-
-    layout: torch.layout
-    batch_dim: int
-    sparse_dim: int
-    dense_dim: int
-    blocksize: Optional[Tuple[int, int]]
-    pos_dtype: torch.dtype
-    crd_dtype: torch.dtype
-
-
-def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
-    """Returns sparse tensor encoding for the given sparse layout as string."""
-    assert sparsity is not None
+def sparsity_encoding(t: torch.Tensor) -> str:
+    """Returns sparse tensor encoding for the given tensor as string."""
 
     # Sparse tensors have the form
     #   [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
     # which map directly to MLIR types.
-    batch_dim, sparse_dim, dense_dim = (
-        sparsity.batch_dim,
-        sparsity.sparse_dim,
-        sparsity.dense_dim,
+    dim, batch_dim, sparse_dim, dense_dim = (
+        t.ndim,
+        t.ndim - t.sparse_dim() - t.dense_dim(),
+        t.sparse_dim(),
+        t.dense_dim(),
     )
-    dim = batch_dim + sparse_dim + dense_dim
-    assert dim == len(shape)
-    blocksize = sparsity.blocksize
-
     dims = ",".join(f"d{d}" for d in range(dim))
 
-    if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim >= 2 and blocksize is None
+    if t.layout is torch.sparse_coo:
+        assert sparse_dim >= 2
         trail_dim = batch_dim + sparse_dim - 1
         coords = ",".join(
             f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
         )
         sep = "," if sparse_dim > 2 else ""
         lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
-    elif sparsity.layout is torch.sparse_csr:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t._indices().dtype  # supports uncoalesced COO tensors
+    elif t.layout is torch.sparse_csr:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
-    elif sparsity.layout is torch.sparse_csc:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t.col_indices().dtype
+    elif t.layout is torch.sparse_csc:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
+        idx_dtype = t.row_indices().dtype
     else:
-        assert sparse_dim == 2 and blocksize is not None
-        if sparsity.layout is torch.sparse_bsr:
+        assert sparse_dim == 2
+        blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
+        if t.layout is torch.sparse_bsr:
             i, j = batch_dim, batch_dim + 1
+            idx_dtype = t.col_indices().dtype
         else:
-            assert sparsity.layout is torch.sparse_bsc
+            assert t.layout is torch.sparse_bsc
             j, i = batch_dim, batch_dim + 1
+            idx_dtype = t.row_indices().dtype
         m, n = blocksize
         lvls = (
             f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
@@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
         dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
         lvls = f"{lvls},{dense}"
 
-    posw = torch.iinfo(sparsity.pos_dtype).bits
-    crdw = torch.iinfo(sparsity.crd_dtype).bits
+    posw = crdw = torch.iinfo(idx_dtype).bits
     return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
@@ -1043,20 +1026,27 @@ class ContextCache:
         shape: torch.Size,
         dtype: torch.dtype,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ):
         """Return IrType for !torch.vtensor with the given shape and dtype"""
         stem = "torch.tensor" if mutable else "torch.vtensor"
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
-        if sparsity is not None:
-            encoding = sparsity_encoding(shape, sparsity)
-            assert encoding is not None
+        if val is not None and val.layout in [
+            torch.sparse_coo,
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        ]:
+            # This is a sparse tensor.
+            encoding = sparsity_encoding(val)
             return IrType.parse(
                 f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
                 context=self._c,
             )
+        # This is a dense tensor.
         return IrType.parse(
             f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
         )
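
Putting the two pieces together, a 4x4 f32 CSR fake tensor would now be typed along these lines (illustrative, derived from the format strings above):

    !torch.vtensor<[4,4],f32,#sparse_tensor.encoding<{map=(d0,d1)->(d0:dense,d1:compressed),posWidth=64,crdWidth=64}>>

whereas a dense value keeps the plain !torch.vtensor<[4,4],f32> form.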
@@ -1065,21 +1055,17 @@ class ContextCache:
         try:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
-            sparsity = node.meta.get("sparsity", None)
         except KeyError as e:
             raise RuntimeError(
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
-        return self.value_info_to_type(
-            val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
-        )
+        return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)
 
     def value_info_to_type(
         self,
         val,
         *,
         tensor_meta: Optional[TensorMetadata] = None,
-        sparsity=None,
         mutable: bool = False,
     ):
         if tensor_meta is not None:
@@ -1097,14 +1083,14 @@ class ContextCache:
                 )
             else:
                 return self.tensor_metadata_to_type(
-                    tensor_meta, sparsity=sparsity, mutable=mutable
+                    tensor_meta, val=val, mutable=mutable
                 )
         elif val is not None:
             # some nodes with symbolic inputs pass a 'val' attribute rather than
             # tensor_meta
             if isinstance(val, TorchFakeTensor):
                 return self.get_vtensor_type(
-                    val.size(), val.dtype, sparsity=sparsity, mutable=mutable
+                    val.size(), val.dtype, val=val, mutable=mutable
                 )
             elif isinstance(val, list) and all(
                 isinstance(x, TorchFakeTensor) for x in val
@@ -1126,19 +1112,17 @@ class ContextCache:
         self,
         tm: TensorMetadata,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )
-        key = (tm_shape, tm.dtype, sparsity, mutable)
+        key = (tm_shape, tm.dtype, val, mutable)
         t = self._tensor_metadata_cache.get(key)
         if t is None:
-            t = self.get_vtensor_type(
-                tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
-            )
+            t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
             self._tensor_metadata_cache[key] = t
         return t


@@ -0,0 +1,10 @@
+config.unsupported = True
+
+try:
+    import torch
+
+    if "2.5.0" <= str(torch.__version__):
+        print("Enabling sparsity propagation tests")
+        config.unsupported = False
+except ModuleNotFoundError:
+    ...
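
One subtlety worth noting: the gate compares version strings lexicographically rather than numerically, which gives the intended result for the versions in play here. A quick check (hypothetical version values):

    for v in ("2.4.1", "2.5.0", "2.5.0.dev20240801", "2.6.0"):
        print(v, "2.5.0" <= v)  # False only for 2.4.1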


@@ -8,13 +8,11 @@
 from typing import Any, Callable, Optional, Tuple, Dict
 
 import torch
-import torch.export
 import torch.nn as nn
 import numpy as np
 
 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
-from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir
 from torch_mlir.dialects import torch as torch_d
 from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -23,139 +21,15 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
 )
 
 
-# All sparse layouts currently supported in torch.sparse.
-SPARSE_LAYOUTS = [
-    torch.sparse_coo,
-    torch.sparse_csr,
-    torch.sparse_csc,
-    torch.sparse_bsr,
-    torch.sparse_bsc,
-]
-
-
-def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
-    """
-    Returns a meta data tuple for the given sparse tensor.
-    NOTE: this will be fully replaced by fx graph SparseTensorMetadata
-    """
-    sparse_dim = a.sparse_dim()
-    dense_dim = a.dense_dim()
-    batch_dim = a.ndim - dense_dim - sparse_dim
-    blocksize = None
-    if a.layout is torch.sparse_coo:
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a._indices().dtype,
-            a._indices().dtype,
-        )
-    elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
-        if a.layout is torch.sparse_bsr:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.crow_indices().dtype,
-            a.col_indices().dtype,
-        )
-    elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
-        if a.layout is torch.sparse_bsc:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.ccol_indices().dtype,
-            a.row_indices().dtype,
-        )
-    else:
-        raise RuntimeError(f"Unsupported sparse layout for {a}")
-
-
-def sparse_export(
-    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> torch.export.ExportedProgram:
-    """
-    This is a ***temporary*** wrapper around `torch.export.export`
-    that eventually should be removed and simply replaced by the
-    standard API for exporting traced graphs.
-
-    But until issue
-
-      https://github.com/pytorch/pytorch/pull/117907
-
-    is addressed, this wrapper provides support for the sparse
-    tensor types by first converting all operands to dense tensors,
-    building the traced graph as for the dense case, then annotating
-    sparse parameters with their actual sparse layout attributes,
-    followed by some simple propagation rules. This temporary solution
-    accelerates testing torch-mlir with PyTorch sparse tensors until
-    the issue is resolved upstream.
-    """
-    # Convert all arguments to dense.
-    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
-    mask = [a.layout in SPARSE_LAYOUTS for a in args]
-    # Build the regular FX traced graph with only dense arguments
-    # (the current version would crash otherwise, see issue above).
-    prog = torch.export.export(f, dargs, kwargs)
-    decomposition_table = get_decomposition_table()
-    if decomposition_table:
-        prog = prog.run_decompositions(decomposition_table)
-    # Annotate sparse arguments in the graph and apply some very
-    # basic propagation rules for sparsity.
-    specs = prog.graph_signature.input_specs
-    alen = len(specs)
-    k = 0
-    for i, node in enumerate(prog.graph.nodes):
-        if node.op == "placeholder":
-            # Argument.
-            spec = specs[i]
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                if mask[k]:
-                    node.meta["sparsity"] = sparse_metadata(args[k])
-                k = k + 1
-        elif node.op == "call_function":
-            opname = node.target._schema.name.split("::")[1]
-            # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin"}:
-                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "_to_sparse" or opname == "to_sparse":
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense" or opname == "to_dense":
-            #     # hack (assumes we never really want the to_dense for now)
-            #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "select" and node.args[0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
-                )
-    return prog
-
-
 def export_and_import(f, *args, **kwargs):
-    """This method implements Stella's importer, stripped down to essentials."""
+    """A FX graph importer, stripped down to essentials."""
     context = ir.Context()
     torch_d.register_dialect(context)
     fx_importer = FxImporter(context=context)
-    prog = sparse_export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     fx_importer.import_frozen_program(prog)
     return fx_importer.module
@@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs):
         enable_ir_printing=False,
     )
     # Compile with reference Linalg backend.
-    # TODO: runtime verification currently fails with 'rank mismatch' on
-    #       memref.cast. Need to fix the IR first.
+    # TODO: runtime verification fails with 'rank mismatch' on memref.cast
     backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
@@ -218,7 +91,8 @@ def sparse_jit(f, *args, **kwargs):
 
 def run(f):
-    print(f"{f.__name__}")
+    # Print test name and torch version (for debugging).
+    print(f"{f.__name__} ({torch.__version__})")
     print("-" * len(f.__name__))
     f()
     print()
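
Downstream, tests in this file keep the same shape as before; a minimal sketch of one, with a hypothetical name and body, using the stripped-down importer above:

    @run
    def test_sparse_id():
        class IdNet(torch.nn.Module):
            def forward(self, x):
                return x

        m = export_and_import(IdNet(), torch.eye(4).to_sparse_csr())
        print(m)  # module whose argument is a sparse-encoded !torch.vtensor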