mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648)
Now that the PyDev feature request pytorch/pytorch#117188 has been completed, we can remove all the ad-hoc code that propagates sparsity metadata and replace it with the built-in PyDev metadata for sparse tensors. This removes a lot of code and also ensures sparsity is consistent with the torch.sparse package for all cases.
parent 0a86deb59a
commit f72770a725
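The core of the change is that the FX importer no longer needs a side-channel SparsityMeta record: every field it used to carry (layout, batch/sparse/dense dimensions, blocksize, index dtypes) can be queried directly from the sparse tensor that torch.export now records in node.meta["val"]. A minimal sketch of those queries, using an illustrative CSR tensor (the tensor `a` and the printed values are examples, not taken from the patch):

import torch

# Illustrative 2-D CSR tensor; not part of the patch.
a = torch.tensor([[0.0, 1.0], [2.0, 0.0]]).to_sparse_csr()

# Everything the removed SparsityMeta carried is recoverable from the tensor:
print(a.layout)                                 # torch.sparse_csr
print(a.sparse_dim(), a.dense_dim())            # 2 0
print(a.ndim - a.sparse_dim() - a.dense_dim())  # 0 (batch dimensions)
print(a.crow_indices().dtype)                   # torch.int64 -> posWidth=64
print(a.col_indices().dtype)                    # torch.int64 -> crdWidth=64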
@@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
         )


-@dataclass(frozen=True)
-class SparsityMeta:
-    """
-    Class for keeping track of sparsity meta data.
-
-    NOTE: this will be fully replaced by
-    torch.fx.passes.shape_prop.SparseTensorMetadata
-    """
-
-    layout: torch.layout
-    batch_dim: int
-    sparse_dim: int
-    dense_dim: int
-    blocksize: Optional[Tuple[int, int]]
-    pos_dtype: torch.dtype
-    crd_dtype: torch.dtype
-
-
-def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
-    """Returns sparse tensor encoding for the given sparse layout as string."""
-    assert sparsity is not None
-
+def sparsity_encoding(t: torch.Tensor) -> str:
+    """Returns sparse tensor encoding for the given tensor as string."""
     # Sparse tensors have the form
     #   [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
     # which map directly to MLIR types.
-    batch_dim, sparse_dim, dense_dim = (
-        sparsity.batch_dim,
-        sparsity.sparse_dim,
-        sparsity.dense_dim,
+    dim, batch_dim, sparse_dim, dense_dim = (
+        t.ndim,
+        t.ndim - t.sparse_dim() - t.dense_dim(),
+        t.sparse_dim(),
+        t.dense_dim(),
     )
-    dim = batch_dim + sparse_dim + dense_dim
-    assert dim == len(shape)
-    blocksize = sparsity.blocksize

     dims = ",".join(f"d{d}" for d in range(dim))

-    if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim >= 2 and blocksize is None
+    if t.layout is torch.sparse_coo:
+        assert sparse_dim >= 2
         trail_dim = batch_dim + sparse_dim - 1
         coords = ",".join(
             f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
         )
         sep = "," if sparse_dim > 2 else ""
         lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
-    elif sparsity.layout is torch.sparse_csr:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t._indices().dtype  # supports uncoalesced COO tensors
+    elif t.layout is torch.sparse_csr:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
-    elif sparsity.layout is torch.sparse_csc:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t.col_indices().dtype
+    elif t.layout is torch.sparse_csc:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
+        idx_dtype = t.row_indices().dtype
     else:
-        assert sparse_dim == 2 and blocksize is not None
-        if sparsity.layout is torch.sparse_bsr:
+        assert sparse_dim == 2
+        blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
+        if t.layout is torch.sparse_bsr:
             i, j = batch_dim, batch_dim + 1
+            idx_dtype = t.col_indices().dtype
         else:
-            assert sparsity.layout is torch.sparse_bsc
+            assert t.layout is torch.sparse_bsc
             j, i = batch_dim, batch_dim + 1
+            idx_dtype = t.row_indices().dtype
         m, n = blocksize
         lvls = (
             f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
@@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
         dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
         lvls = f"{lvls},{dense}"

-    posw = torch.iinfo(sparsity.pos_dtype).bits
-    crdw = torch.iinfo(sparsity.crd_dtype).bits
+    posw = crdw = torch.iinfo(idx_dtype).bits
     return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"

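For the 2-D CSR tensor from the earlier sketch, the new sparsity_encoding computes batch_dim = 0, sparse_dim = 2, dense_dim = 0 and 64-bit index dtypes, so dims is "d0,d1", lvls is "d0:dense,d1:compressed", and posw = crdw = 64. The expected return value, written out as an illustrative check (not output copied from the patch, and assuming torch-mlir is installed so the helper can be imported from the FX importer module):

import torch
from torch_mlir.extras.fx_importer import sparsity_encoding

# Same illustrative CSR tensor as in the sketch above.
a = torch.tensor([[0.0, 1.0], [2.0, 0.0]]).to_sparse_csr()
assert sparsity_encoding(a) == (
    "#sparse_tensor.encoding<{map=(d0,d1)->"
    "(d0:dense,d1:compressed),posWidth=64,crdWidth=64}>"
)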
|
@@ -1043,20 +1026,27 @@ class ContextCache:
         shape: torch.Size,
         dtype: torch.dtype,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ):
         """Return IrType for !torch.vtensor with the given shape and dtype"""
         stem = "torch.tensor" if mutable else "torch.vtensor"
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
-        if sparsity is not None:
-            encoding = sparsity_encoding(shape, sparsity)
-            assert encoding is not None
+        if val is not None and val.layout in [
+            torch.sparse_coo,
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        ]:
+            # This is a sparse tensor.
+            encoding = sparsity_encoding(val)
             return IrType.parse(
                 f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
                 context=self._c,
             )
+        # This is a dense tensor.
         return IrType.parse(
             f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
         )
@@ -1065,21 +1055,17 @@ class ContextCache:
         try:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
-            sparsity = node.meta.get("sparsity", None)
         except KeyError as e:
             raise RuntimeError(
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
-        return self.value_info_to_type(
-            val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
-        )
+        return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)

     def value_info_to_type(
         self,
         val,
         *,
         tensor_meta: Optional[TensorMetadata] = None,
-        sparsity=None,
         mutable: bool = False,
     ):
         if tensor_meta is not None:
@@ -1097,14 +1083,14 @@ class ContextCache:
                 )
             else:
                 return self.tensor_metadata_to_type(
-                    tensor_meta, sparsity=sparsity, mutable=mutable
+                    tensor_meta, val=val, mutable=mutable
                 )
         elif val is not None:
             # some nodes with symbolic inputs pass a 'val' attribute rather than
             # tensor_meta
             if isinstance(val, TorchFakeTensor):
                 return self.get_vtensor_type(
-                    val.size(), val.dtype, sparsity=sparsity, mutable=mutable
+                    val.size(), val.dtype, val=val, mutable=mutable
                 )
             elif isinstance(val, list) and all(
                 isinstance(x, TorchFakeTensor) for x in val
@@ -1126,19 +1112,17 @@ class ContextCache:
         self,
         tm: TensorMetadata,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )

-        key = (tm_shape, tm.dtype, sparsity, mutable)
+        key = (tm_shape, tm.dtype, val, mutable)
         t = self._tensor_metadata_cache.get(key)
         if t is None:
-            t = self.get_vtensor_type(
-                tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
-            )
+            t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
             self._tensor_metadata_cache[key] = t
         return t

@@ -0,0 +1,10 @@
+config.unsupported = True
+
+try:
+    import torch
+    if "2.5.0" <= str(torch.__version__):
+        print("Enabling sparsity propagation tests")
+        config.unsupported = False
+
+except ModuleNotFoundError:
+    ...
@@ -8,13 +8,11 @@
 from typing import Any, Callable, Optional, Tuple, Dict

 import torch
-import torch.export
 import torch.nn as nn
 import numpy as np

 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
-from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir
 from torch_mlir.dialects import torch as torch_d
 from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -23,139 +21,15 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
 )


-# All sparse layouts currently supported in torch.sparse.
-SPARSE_LAYOUTS = [
-    torch.sparse_coo,
-    torch.sparse_csr,
-    torch.sparse_csc,
-    torch.sparse_bsr,
-    torch.sparse_bsc,
-]
-
-
-def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
-    """
-    Returns a meta data tuple for the given sparse tensor.
-
-    NOTE: this will be fully replaced by fx graph SparseTensorMetadata
-    """
-    sparse_dim = a.sparse_dim()
-    dense_dim = a.dense_dim()
-    batch_dim = a.ndim - dense_dim - sparse_dim
-    blocksize = None
-    if a.layout is torch.sparse_coo:
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a._indices().dtype,
-            a._indices().dtype,
-        )
-    elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
-        if a.layout is torch.sparse_bsr:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.crow_indices().dtype,
-            a.col_indices().dtype,
-        )
-    elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
-        if a.layout is torch.sparse_bsc:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.ccol_indices().dtype,
-            a.row_indices().dtype,
-        )
-    else:
-        raise RuntimeError(f"Unsupported sparse layout for {a}")
-
-
-def sparse_export(
-    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> torch.export.ExportedProgram:
-    """
-    This is a ***temporary*** wrapper around `torch.export.export`
-    that eventually should be removed and simply replaced by the
-    standard API for exporting traced graphs.
-
-    But until issue
-
-      https://github.com/pytorch/pytorch/pull/117907
-
-    is addressed, this wrapper provides support for the sparse
-    tensor types by first converting all operands to dense tensors,
-    building the traced graph as for the dense case, then annotating
-    sparse parameters with their actual sparse layout attributes,
-    followed by some simple propagation rules. This temporary solution
-    accelerates testing torch-mlir with PyTorch sparse tensors until
-    the issue is resolved upstream.
-    """
-    # Convert all arguments to dense.
-    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
-    mask = [a.layout in SPARSE_LAYOUTS for a in args]
-    # Build the regular FX traced graph with only dense arguments
-    # (the current version would crash otherwise, see issue above).
-    prog = torch.export.export(f, dargs, kwargs)
-    decomposition_table = get_decomposition_table()
-    if decomposition_table:
-        prog = prog.run_decompositions(decomposition_table)
-    # Annotate sparse arguments in the graph and apply some very
-    # basic propagation rules for sparsity.
-    specs = prog.graph_signature.input_specs
-    alen = len(specs)
-    k = 0
-    for i, node in enumerate(prog.graph.nodes):
-        if node.op == "placeholder":
-            # Argument.
-            spec = specs[i]
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                if mask[k]:
-                    node.meta["sparsity"] = sparse_metadata(args[k])
-                k = k + 1
-        elif node.op == "call_function":
-            opname = node.target._schema.name.split("::")[1]
-            # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin"}:
-                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "_to_sparse" or opname == "to_sparse":
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense" or opname == "to_dense":
-            #   # hack (assumes we never really want the to_dense for now)
-            #   node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "select" and node.args[0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
-                )
-    return prog
-
-
 def export_and_import(f, *args, **kwargs):
-    """This method implements Stella's importer, stripped down to essentials."""
+    """A FX graph importer, stripped down to essentials."""
     context = ir.Context()
     torch_d.register_dialect(context)
     fx_importer = FxImporter(context=context)
-    prog = sparse_export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     fx_importer.import_frozen_program(prog)
     return fx_importer.module

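With sparse_export gone, the test helper is just torch.export.export plus an optional decomposition pass and the FX import. A minimal usage sketch of the export_and_import helper defined in this test file, assuming PyTorch >= 2.5 (the version gated by the new lit config above) so that torch.export accepts sparse inputs directly; SumNet and the CSR input are illustrative and not taken from the patch:

import torch
import torch.nn as nn


class SumNet(nn.Module):  # illustrative module, not part of the patch
    def forward(self, x):
        return x.sum()


# torch.export records the sparse layout of the input in node.meta["val"];
# the importer then attaches a #sparse_tensor.encoding to the vtensor type.
sparse_input = torch.eye(4).to_sparse_csr()
module = export_and_import(SumNet(), sparse_input)
print(module)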
@@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs):
         enable_ir_printing=False,
     )
     # Compile with reference Linalg backend.
-    # TODO: runtime verification currently fails with 'rank mismatch' on
-    # memref.cast. Need to fix the IR first.
+    # TODO: runtime verification fails with 'rank mismatch' on memref.cast
     backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
@@ -218,7 +91,8 @@ def sparse_jit(f, *args, **kwargs):


 def run(f):
-    print(f"{f.__name__}")
+    # Print test name and torch version (for debugging).
+    print(f"{f.__name__} ({torch.__version__})")
     print("-" * len(f.__name__))
     f()
     print()