[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648)

Now that the PyTorch feature request pytorch/pytorch#117188 has been
completed, we can remove all the ad-hoc code that propagates sparsity
metadata and replace it with the built-in PyTorch metadata for sparse
tensors. This removes a lot of code and also ensures sparsity is
consistent with the torch.sparse package in all cases.
Aart Bik 2024-08-20 09:56:21 -07:00 committed by GitHub
parent 0a86deb59a
commit f72770a725
3 changed files with 58 additions and 190 deletions
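For context, a minimal sketch (not part of this commit) of the export flow the new importer code relies on, assuming PyTorch 2.5+ where torch.export.export accepts sparse operands directly and records their layout on the FakeTensor stored in each placeholder's node.meta["val"]; the module, shapes, and names below are purely illustrative:

import torch
import torch.export

class MatVec(torch.nn.Module):
    def forward(self, a, x):
        return torch.mv(a, x)

a = torch.eye(8).to_sparse_csr()  # sparse CSR operand
x = torch.ones(8)                 # dense operand
prog = torch.export.export(MatVec(), (a, x))
for node in prog.graph.nodes:
    if node.op == "placeholder":
        fake = node.meta["val"]  # FakeTensor carrying the built-in sparsity metadata
        if fake.layout is torch.sparse_csr:
            print(node.name, fake.layout, fake.sparse_dim(), fake.dense_dim())
        else:
            print(node.name, fake.layout)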

@@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
)
@dataclass(frozen=True)
class SparsityMeta:
"""
Class for keeping track of sparsity meta data.
NOTE: this will be fully replaced by
torch.fx.passes.shape_prop.SparseTensorMetadata
"""
layout: torch.layout
batch_dim: int
sparse_dim: int
dense_dim: int
blocksize: Optional[Tuple[int, int]]
pos_dtype: torch.dtype
crd_dtype: torch.dtype
def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
"""Returns sparse tensor encoding for the given sparse layout as string."""
assert sparsity is not None
def sparsity_encoding(t: torch.Tensor) -> str:
"""Returns sparse tensor encoding for the given tensor as string."""
# Sparse tensors have the form
# [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
# which map directly to MLIR types.
batch_dim, sparse_dim, dense_dim = (
sparsity.batch_dim,
sparsity.sparse_dim,
sparsity.dense_dim,
dim, batch_dim, sparse_dim, dense_dim = (
t.ndim,
t.ndim - t.sparse_dim() - t.dense_dim(),
t.sparse_dim(),
t.dense_dim(),
)
dim = batch_dim + sparse_dim + dense_dim
assert dim == len(shape)
blocksize = sparsity.blocksize
dims = ",".join(f"d{d}" for d in range(dim))
if sparsity.layout is torch.sparse_coo:
assert sparse_dim >= 2 and blocksize is None
if t.layout is torch.sparse_coo:
assert sparse_dim >= 2
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
assert sparse_dim == 2 and blocksize is None
idx_dtype = t._indices().dtype # supports uncoalesced COO tensors
elif t.layout is torch.sparse_csr:
assert sparse_dim == 2
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
elif sparsity.layout is torch.sparse_csc:
assert sparse_dim == 2 and blocksize is None
idx_dtype = t.col_indices().dtype
elif t.layout is torch.sparse_csc:
assert sparse_dim == 2
lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
idx_dtype = t.row_indices().dtype
else:
assert sparse_dim == 2 and blocksize is not None
if sparsity.layout is torch.sparse_bsr:
assert sparse_dim == 2
blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
if t.layout is torch.sparse_bsr:
i, j = batch_dim, batch_dim + 1
idx_dtype = t.col_indices().dtype
else:
assert sparsity.layout is torch.sparse_bsc
assert t.layout is torch.sparse_bsc
j, i = batch_dim, batch_dim + 1
idx_dtype = t.row_indices().dtype
m, n = blocksize
lvls = (
f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
@@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
lvls = f"{lvls},{dense}"
posw = torch.iinfo(sparsity.pos_dtype).bits
crdw = torch.iinfo(sparsity.crd_dtype).bits
posw = crdw = torch.iinfo(idx_dtype).bits
return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
@@ -1043,20 +1026,27 @@ class ContextCache:
shape: torch.Size,
dtype: torch.dtype,
*,
sparsity: Optional[SparsityMeta] = None,
val: Optional[torch.Tensor] = None,
mutable: bool = False,
):
"""Return IrType for !torch.vtensor with the given shape and dtype"""
stem = "torch.tensor" if mutable else "torch.vtensor"
shape_asm = self.format_asm_shape(shape)
mlir_dtype = str(self.dtype_to_type(dtype))
if sparsity is not None:
encoding = sparsity_encoding(shape, sparsity)
assert encoding is not None
if val is not None and val.layout in [
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
]:
# This is a sparse tensor.
encoding = sparsity_encoding(val)
return IrType.parse(
f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
context=self._c,
)
# This is a dense tensor.
return IrType.parse(
f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
)
@@ -1065,21 +1055,17 @@ class ContextCache:
try:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparsity = node.meta.get("sparsity", None)
except KeyError as e:
raise RuntimeError(
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
)
return self.value_info_to_type(
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
)
return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)
def value_info_to_type(
self,
val,
*,
tensor_meta: Optional[TensorMetadata] = None,
sparsity=None,
mutable: bool = False,
):
if tensor_meta is not None:
@@ -1097,14 +1083,14 @@ class ContextCache:
)
else:
return self.tensor_metadata_to_type(
tensor_meta, sparsity=sparsity, mutable=mutable
tensor_meta, val=val, mutable=mutable
)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
if isinstance(val, TorchFakeTensor):
return self.get_vtensor_type(
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
val.size(), val.dtype, val=val, mutable=mutable
)
elif isinstance(val, list) and all(
isinstance(x, TorchFakeTensor) for x in val
@@ -1126,19 +1112,17 @@
self,
tm: TensorMetadata,
*,
sparsity: Optional[SparsityMeta] = None,
val: Optional[torch.Tensor] = None,
mutable: bool = False,
) -> IrType:
tm_shape = tuple(
item.node if is_symbolic(item) else item for item in list(tm.shape)
)
key = (tm_shape, tm.dtype, sparsity, mutable)
key = (tm_shape, tm.dtype, val, mutable)
t = self._tensor_metadata_cache.get(key)
if t is None:
t = self.get_vtensor_type(
tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
)
t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
self._tensor_metadata_cache[key] = t
return t

@@ -0,0 +1,10 @@
config.unsupported = True
try:
    import torch
    if "2.5.0" <= str(torch.__version__):
        print("Enabling sparsity propagation tests")
        config.unsupported = False
except ModuleNotFoundError:
    ...
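Note that the gate above is a plain lexicographic comparison on the version string; a small sketch of how it behaves (the version strings are illustrative):

# Works for the 2.4.x -> 2.5.x transition, including dev builds:
print("2.5.0" <= "2.5.0.dev20240801")  # True  -> tests enabled
print("2.5.0" <= "2.4.1")              # False -> tests stay unsupported
# Known caveat of string comparison: "2.10.0" would sort before "2.5.0".
print("2.5.0" <= "2.10.0")             # False, although 2.10 would be newer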

@@ -8,13 +8,11 @@
from typing import Any, Callable, Optional, Tuple, Dict
import torch
import torch.export
import torch.nn as nn
import numpy as np
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir.extras.fx_importer import SparsityMeta
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -23,139 +21,15 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
)
# All sparse layouts currently supported in torch.sparse.
SPARSE_LAYOUTS = [
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
]
def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
"""
Returns a meta data tuple for the given sparse tensor.
NOTE: this will be fully replaced by fx graph SparseTensorMetadata
"""
sparse_dim = a.sparse_dim()
dense_dim = a.dense_dim()
batch_dim = a.ndim - dense_dim - sparse_dim
blocksize = None
if a.layout is torch.sparse_coo:
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a._indices().dtype,
a._indices().dtype,
)
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
if a.layout is torch.sparse_bsr:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.crow_indices().dtype,
a.col_indices().dtype,
)
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
if a.layout is torch.sparse_bsc:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.ccol_indices().dtype,
a.row_indices().dtype,
)
else:
raise RuntimeError(f"Unsupported sparse layout for {a}")
def sparse_export(
f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> torch.export.ExportedProgram:
"""
This is a ***temporary*** wrapper around `torch.export.export`
that eventually should be removed and simply replaced by the
standard API for exporting traced graphs.
But until issue
https://github.com/pytorch/pytorch/pull/117907
is addressed, this wrapper provides support for the sparse
tensor types by first converting all operands to dense tensors,
building the traced graph as for the dense case, then annotating
sparse parameters with their actual sparse layout attributes,
followed by some simple propagation rules. This temporary solution
accelerates testing torch-mlir with PyTorch sparse tensors until
the issue is resolved upstream.
"""
# Convert all arguments to dense.
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
mask = [a.layout in SPARSE_LAYOUTS for a in args]
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
# Annotate sparse arguments in the graph and apply some very
# basic propagation rules for sparsity.
specs = prog.graph_signature.input_specs
alen = len(specs)
k = 0
for i, node in enumerate(prog.graph.nodes):
if node.op == "placeholder":
# Argument.
spec = specs[i]
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
if mask[k]:
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if opname in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "_to_sparse" or opname == "to_sparse":
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif opname == "_to_dense" or opname == "to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "select" and node.args[0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
)
return prog
def export_and_import(f, *args, **kwargs):
"""This method implements Stella's importer, stripped down to essentials."""
"""A FX graph importer, stripped down to essentials."""
context = ir.Context()
torch_d.register_dialect(context)
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
prog = torch.export.export(f, args, kwargs)
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
fx_importer.import_frozen_program(prog)
return fx_importer.module
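Hypothetical usage of the simplified helper above (module and shapes are illustrative, not taken from the tests below); sparse operands now go straight into torch.export.export:

class SpMVNet(torch.nn.Module):
    def forward(self, a, x):
        return torch.mv(a, x)

a = torch.eye(16).to_sparse_csr()
x = torch.randn(16)
m = export_and_import(SpMVNet(), a, x)
print(m)  # MLIR module whose first argument is a sparse-encoded !torch.vtensor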
@@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs):
enable_ir_printing=False,
)
# Compile with reference Linalg backend.
# TODO: runtime verification currently fails with 'rank mismatch' on
# memref.cast. Need to fix the IR first.
# TODO: runtime verification fails with 'rank mismatch' on memref.cast
backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
compiled = backend.compile(module)
invoker = backend.load(compiled)
@@ -218,7 +91,8 @@
def run(f):
print(f"{f.__name__}")
# Print test name and torch version (for debugging).
print(f"{f.__name__} ({torch.__version__})")
print("-" * len(f.__name__))
f()
print()
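For reference, a sketch of how a test in this file typically combines the pieces above; the network, name, and expected output are illustrative, not an actual test from the suite:

@run
def test_sparse_id_sketch():
    class IdNet(torch.nn.Module):
        def forward(self, x):
            return x

    sparse_input = torch.eye(4).to_sparse_csr()
    module = export_and_import(IdNet(), sparse_input)
    print(module)  # the real tests FileCheck this IR
    res1 = IdNet()(sparse_input)              # torch.sparse reference result
    res2 = sparse_jit(IdNet(), sparse_input)  # result via the MLIR pipeline
    print("torch:", res1)
    print("mlir: ", res2)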