mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648)
Now that the PyDev feature request pytorch/pytorch#117188 has been completed, we can remove all the ad-hoc code that propagates sparsity metadata and replace it with the built-int PyDev metadata for sparse tensors. This removes a lot of code and also ensures sparsity is consistent with the torch.sparse package for all cases.pull/3614/merge
parent
0a86deb59a
commit
f72770a725
|
@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
|
|||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SparsityMeta:
|
||||
"""
|
||||
Class for keeping track of sparsity meta data.
|
||||
|
||||
NOTE: this will be fully replaced by
|
||||
torch.fx.passes.shape_prop.SparseTensorMetadata
|
||||
"""
|
||||
|
||||
layout: torch.layout
|
||||
batch_dim: int
|
||||
sparse_dim: int
|
||||
dense_dim: int
|
||||
blocksize: Optional[Tuple[int, int]]
|
||||
pos_dtype: torch.dtype
|
||||
crd_dtype: torch.dtype
|
||||
|
||||
|
||||
def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
|
||||
"""Returns sparse tensor encoding for the given sparse layout as string."""
|
||||
assert sparsity is not None
|
||||
def sparsity_encoding(t: torch.Tensor) -> str:
|
||||
"""Returns sparse tensor encoding for the given tensor as string."""
|
||||
|
||||
# Sparse tensors have the form
|
||||
# [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
|
||||
# which map directly to MLIR types.
|
||||
batch_dim, sparse_dim, dense_dim = (
|
||||
sparsity.batch_dim,
|
||||
sparsity.sparse_dim,
|
||||
sparsity.dense_dim,
|
||||
dim, batch_dim, sparse_dim, dense_dim = (
|
||||
t.ndim,
|
||||
t.ndim - t.sparse_dim() - t.dense_dim(),
|
||||
t.sparse_dim(),
|
||||
t.dense_dim(),
|
||||
)
|
||||
dim = batch_dim + sparse_dim + dense_dim
|
||||
assert dim == len(shape)
|
||||
blocksize = sparsity.blocksize
|
||||
|
||||
dims = ",".join(f"d{d}" for d in range(dim))
|
||||
|
||||
if sparsity.layout is torch.sparse_coo:
|
||||
assert sparse_dim >= 2 and blocksize is None
|
||||
if t.layout is torch.sparse_coo:
|
||||
assert sparse_dim >= 2
|
||||
trail_dim = batch_dim + sparse_dim - 1
|
||||
coords = ",".join(
|
||||
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
|
||||
)
|
||||
sep = "," if sparse_dim > 2 else ""
|
||||
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
|
||||
elif sparsity.layout is torch.sparse_csr:
|
||||
assert sparse_dim == 2 and blocksize is None
|
||||
idx_dtype = t._indices().dtype # supports uncoalesced COO tensors
|
||||
elif t.layout is torch.sparse_csr:
|
||||
assert sparse_dim == 2
|
||||
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
|
||||
elif sparsity.layout is torch.sparse_csc:
|
||||
assert sparse_dim == 2 and blocksize is None
|
||||
idx_dtype = t.col_indices().dtype
|
||||
elif t.layout is torch.sparse_csc:
|
||||
assert sparse_dim == 2
|
||||
lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
|
||||
idx_dtype = t.row_indices().dtype
|
||||
else:
|
||||
assert sparse_dim == 2 and blocksize is not None
|
||||
if sparsity.layout is torch.sparse_bsr:
|
||||
assert sparse_dim == 2
|
||||
blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
|
||||
if t.layout is torch.sparse_bsr:
|
||||
i, j = batch_dim, batch_dim + 1
|
||||
idx_dtype = t.col_indices().dtype
|
||||
else:
|
||||
assert sparsity.layout is torch.sparse_bsc
|
||||
assert t.layout is torch.sparse_bsc
|
||||
j, i = batch_dim, batch_dim + 1
|
||||
idx_dtype = t.row_indices().dtype
|
||||
m, n = blocksize
|
||||
lvls = (
|
||||
f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
|
||||
|
@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
|
|||
dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
|
||||
lvls = f"{lvls},{dense}"
|
||||
|
||||
posw = torch.iinfo(sparsity.pos_dtype).bits
|
||||
crdw = torch.iinfo(sparsity.crd_dtype).bits
|
||||
posw = crdw = torch.iinfo(idx_dtype).bits
|
||||
return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
|
||||
|
||||
|
||||
|
@ -1043,20 +1026,27 @@ class ContextCache:
|
|||
shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
sparsity: Optional[SparsityMeta] = None,
|
||||
val: Optional[torch.Tensor] = None,
|
||||
mutable: bool = False,
|
||||
):
|
||||
"""Return IrType for !torch.vtensor with the given shape and dtype"""
|
||||
stem = "torch.tensor" if mutable else "torch.vtensor"
|
||||
shape_asm = self.format_asm_shape(shape)
|
||||
mlir_dtype = str(self.dtype_to_type(dtype))
|
||||
if sparsity is not None:
|
||||
encoding = sparsity_encoding(shape, sparsity)
|
||||
assert encoding is not None
|
||||
if val is not None and val.layout in [
|
||||
torch.sparse_coo,
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
]:
|
||||
# This is a sparse tensor.
|
||||
encoding = sparsity_encoding(val)
|
||||
return IrType.parse(
|
||||
f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
|
||||
context=self._c,
|
||||
)
|
||||
# This is a dense tensor.
|
||||
return IrType.parse(
|
||||
f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
|
||||
)
|
||||
|
@ -1065,21 +1055,17 @@ class ContextCache:
|
|||
try:
|
||||
tensor_meta = node.meta.get("tensor_meta")
|
||||
val = node.meta.get("val")
|
||||
sparsity = node.meta.get("sparsity", None)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||
)
|
||||
return self.value_info_to_type(
|
||||
val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)
|
||||
|
||||
def value_info_to_type(
|
||||
self,
|
||||
val,
|
||||
*,
|
||||
tensor_meta: Optional[TensorMetadata] = None,
|
||||
sparsity=None,
|
||||
mutable: bool = False,
|
||||
):
|
||||
if tensor_meta is not None:
|
||||
|
@ -1097,14 +1083,14 @@ class ContextCache:
|
|||
)
|
||||
else:
|
||||
return self.tensor_metadata_to_type(
|
||||
tensor_meta, sparsity=sparsity, mutable=mutable
|
||||
tensor_meta, val=val, mutable=mutable
|
||||
)
|
||||
elif val is not None:
|
||||
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
||||
# tensor_meta
|
||||
if isinstance(val, TorchFakeTensor):
|
||||
return self.get_vtensor_type(
|
||||
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
|
||||
val.size(), val.dtype, val=val, mutable=mutable
|
||||
)
|
||||
elif isinstance(val, list) and all(
|
||||
isinstance(x, TorchFakeTensor) for x in val
|
||||
|
@ -1126,19 +1112,17 @@ class ContextCache:
|
|||
self,
|
||||
tm: TensorMetadata,
|
||||
*,
|
||||
sparsity: Optional[SparsityMeta] = None,
|
||||
val: Optional[torch.Tensor] = None,
|
||||
mutable: bool = False,
|
||||
) -> IrType:
|
||||
tm_shape = tuple(
|
||||
item.node if is_symbolic(item) else item for item in list(tm.shape)
|
||||
)
|
||||
|
||||
key = (tm_shape, tm.dtype, sparsity, mutable)
|
||||
key = (tm_shape, tm.dtype, val, mutable)
|
||||
t = self._tensor_metadata_cache.get(key)
|
||||
if t is None:
|
||||
t = self.get_vtensor_type(
|
||||
tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
|
||||
)
|
||||
t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
|
||||
self._tensor_metadata_cache[key] = t
|
||||
return t
|
||||
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
config.unsupported = True
|
||||
|
||||
try:
|
||||
import torch
|
||||
if "2.5.0" <= str(torch.__version__):
|
||||
print("Enabling sparsity propagation tests")
|
||||
config.unsupported = False
|
||||
|
||||
except ModuleNotFoundError:
|
||||
...
|
|
@ -8,13 +8,11 @@
|
|||
from typing import Any, Callable, Optional, Tuple, Dict
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
||||
from torch_mlir.extras.fx_importer import FxImporter
|
||||
from torch_mlir.extras.fx_importer import SparsityMeta
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
|
@ -23,139 +21,15 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
|||
)
|
||||
|
||||
|
||||
# All sparse layouts currently supported in torch.sparse.
|
||||
SPARSE_LAYOUTS = [
|
||||
torch.sparse_coo,
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
]
|
||||
|
||||
|
||||
def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
|
||||
"""
|
||||
Returns a meta data tuple for the given sparse tensor.
|
||||
|
||||
NOTE: this will be fully replaced by fx graph SparseTensorMetadata
|
||||
"""
|
||||
sparse_dim = a.sparse_dim()
|
||||
dense_dim = a.dense_dim()
|
||||
batch_dim = a.ndim - dense_dim - sparse_dim
|
||||
blocksize = None
|
||||
if a.layout is torch.sparse_coo:
|
||||
return SparsityMeta(
|
||||
a.layout,
|
||||
batch_dim,
|
||||
sparse_dim,
|
||||
dense_dim,
|
||||
blocksize,
|
||||
a._indices().dtype,
|
||||
a._indices().dtype,
|
||||
)
|
||||
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
|
||||
if a.layout is torch.sparse_bsr:
|
||||
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
|
||||
return SparsityMeta(
|
||||
a.layout,
|
||||
batch_dim,
|
||||
sparse_dim,
|
||||
dense_dim,
|
||||
blocksize,
|
||||
a.crow_indices().dtype,
|
||||
a.col_indices().dtype,
|
||||
)
|
||||
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
|
||||
if a.layout is torch.sparse_bsc:
|
||||
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
|
||||
return SparsityMeta(
|
||||
a.layout,
|
||||
batch_dim,
|
||||
sparse_dim,
|
||||
dense_dim,
|
||||
blocksize,
|
||||
a.ccol_indices().dtype,
|
||||
a.row_indices().dtype,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported sparse layout for {a}")
|
||||
|
||||
|
||||
def sparse_export(
|
||||
f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> torch.export.ExportedProgram:
|
||||
"""
|
||||
This is a ***temporary*** wrapper around `torch.export.export`
|
||||
that eventually should be removed and simply replaced by the
|
||||
standard API for exporting traced graphs.
|
||||
|
||||
But until issue
|
||||
|
||||
https://github.com/pytorch/pytorch/pull/117907
|
||||
|
||||
is addressed, this wrapper provides support for the sparse
|
||||
tensor types by first converting all operands to dense tensors,
|
||||
building the traced graph as for the dense case, then annotating
|
||||
sparse parameters with their actual sparse layout attributes,
|
||||
followed by some simple propagation rules. This temporary solution
|
||||
accelerates testing torch-mlir with PyTorch sparse tensors until
|
||||
the issue is resolved upstream.
|
||||
"""
|
||||
# Convert all arguments to dense.
|
||||
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
|
||||
mask = [a.layout in SPARSE_LAYOUTS for a in args]
|
||||
# Build the regular FX traced graph with only dense arguments
|
||||
# (the current version would crash otherwise, see issue above).
|
||||
prog = torch.export.export(f, dargs, kwargs)
|
||||
decomposition_table = get_decomposition_table()
|
||||
if decomposition_table:
|
||||
prog = prog.run_decompositions(decomposition_table)
|
||||
# Annotate sparse arguments in the graph and apply some very
|
||||
# basic propagation rules for sparsity.
|
||||
specs = prog.graph_signature.input_specs
|
||||
alen = len(specs)
|
||||
k = 0
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
if node.op == "placeholder":
|
||||
# Argument.
|
||||
spec = specs[i]
|
||||
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
|
||||
if mask[k]:
|
||||
node.meta["sparsity"] = sparse_metadata(args[k])
|
||||
k = k + 1
|
||||
elif node.op == "call_function":
|
||||
opname = node.target._schema.name.split("::")[1]
|
||||
# Zero preserving elt-wise unary op.
|
||||
if opname in {"abs", "neg", "relu", "sin"}:
|
||||
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
elif opname == "_to_sparse" or opname == "to_sparse":
|
||||
dim = len(node.meta.get("val").shape)
|
||||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||
)
|
||||
# TODO: Uncomment this to hack sparsity into the network.
|
||||
# elif opname == "_to_dense" or opname == "to_dense":
|
||||
# # hack (assumes we never really want the to_dense for now)
|
||||
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
elif opname == "select" and node.args[0].meta.get("sparsity", None):
|
||||
dim = len(node.meta.get("val").shape)
|
||||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||
)
|
||||
elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
|
||||
dim = len(node.meta.get("val").shape)
|
||||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
|
||||
)
|
||||
return prog
|
||||
|
||||
|
||||
def export_and_import(f, *args, **kwargs):
|
||||
"""This method implements Stella's importer, stripped down to essentials."""
|
||||
"""A FX graph importer, stripped down to essentials."""
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
fx_importer = FxImporter(context=context)
|
||||
prog = sparse_export(f, args, kwargs)
|
||||
prog = torch.export.export(f, args, kwargs)
|
||||
decomposition_table = get_decomposition_table()
|
||||
if decomposition_table:
|
||||
prog = prog.run_decompositions(decomposition_table)
|
||||
fx_importer.import_frozen_program(prog)
|
||||
return fx_importer.module
|
||||
|
||||
|
@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs):
|
|||
enable_ir_printing=False,
|
||||
)
|
||||
# Compile with reference Linalg backend.
|
||||
# TODO: runtime verification currently fails with 'rank mismatch' on
|
||||
# memref.cast. Need to fix the IR first.
|
||||
# TODO: runtime verification ails with 'rank mismatch' on memref.cast
|
||||
backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
|
||||
compiled = backend.compile(module)
|
||||
invoker = backend.load(compiled)
|
||||
|
@ -218,7 +91,8 @@ def sparse_jit(f, *args, **kwargs):
|
|||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
# Prompt test name and torch version (for debugging).
|
||||
print(f"{f.__name__} ({torch.__version__})")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
Loading…
Reference in New Issue