From f72770a725ef07927b9b665843c936dba6ab1121 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 20 Aug 2024 09:56:21 -0700 Subject: [PATCH] [torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648) Now that the PyDev feature request pytorch/pytorch#117188 has been completed, we can remove all the ad-hoc code that propagates sparsity metadata and replace it with the built-in PyDev metadata for sparse tensors. This removes a lot of code and also ensures sparsity is consistent with the torch.sparse package for all cases. --- python/torch_mlir/extras/fx_importer.py | 96 +++++------- .../python/fx_importer/sparsity/lit.local.cfg | 10 ++ .../fx_importer/{ => sparsity}/sparse_test.py | 142 +----------------- 3 files changed, 58 insertions(+), 190 deletions(-) create mode 100644 test/python/fx_importer/sparsity/lit.local.cfg rename test/python/fx_importer/{ => sparsity}/sparse_test.py (82%) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 99c8d3cfd..6f936e50e 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr( ) -@dataclass(frozen=True) -class SparsityMeta: - """ - Class for keeping track of sparsity meta data. - - NOTE: this will be fully replaced by - torch.fx.passes.shape_prop.SparseTensorMetadata - """ - - layout: torch.layout - batch_dim: int - sparse_dim: int - dense_dim: int - blocksize: Optional[Tuple[int, int]] - pos_dtype: torch.dtype - crd_dtype: torch.dtype - - -def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: - """Returns sparse tensor encoding for the given sparse layout as string.""" - assert sparsity is not None +def sparsity_encoding(t: torch.Tensor) -> str: + """Returns sparse tensor encoding for the given tensor as string.""" # Sparse tensors have the form # [ <batch_dimensions>, <sparse_dimensions>, <dense_dimensions> ] # which map directly to MLIR types. 
- batch_dim, sparse_dim, dense_dim = ( - sparsity.batch_dim, - sparsity.sparse_dim, - sparsity.dense_dim, + dim, batch_dim, sparse_dim, dense_dim = ( + t.ndim, + t.ndim - t.sparse_dim() - t.dense_dim(), + t.sparse_dim(), + t.dense_dim(), ) - dim = batch_dim + sparse_dim + dense_dim - assert dim == len(shape) - blocksize = sparsity.blocksize - dims = ",".join(f"d{d}" for d in range(dim)) - if sparsity.layout is torch.sparse_coo: - assert sparse_dim >= 2 and blocksize is None + if t.layout is torch.sparse_coo: + assert sparse_dim >= 2 trail_dim = batch_dim + sparse_dim - 1 coords = ",".join( f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim) ) sep = "," if sparse_dim > 2 else "" lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" - elif sparsity.layout is torch.sparse_csr: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t._indices().dtype # supports uncoalesced COO tensors + elif t.layout is torch.sparse_csr: + assert sparse_dim == 2 lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" - elif sparsity.layout is torch.sparse_csc: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t.col_indices().dtype + elif t.layout is torch.sparse_csc: + assert sparse_dim == 2 lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" + idx_dtype = t.row_indices().dtype else: - assert sparse_dim == 2 and blocksize is not None - if sparsity.layout is torch.sparse_bsr: + assert sparse_dim == 2 + blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3] + if t.layout is torch.sparse_bsr: i, j = batch_dim, batch_dim + 1 + idx_dtype = t.col_indices().dtype else: - assert sparsity.layout is torch.sparse_bsc + assert t.layout is torch.sparse_bsc j, i = batch_dim, batch_dim + 1 + idx_dtype = t.row_indices().dtype m, n = blocksize lvls = ( f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed," @@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: dense = 
",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) lvls = f"{lvls},{dense}" - posw = torch.iinfo(sparsity.pos_dtype).bits - crdw = torch.iinfo(sparsity.crd_dtype).bits + posw = crdw = torch.iinfo(idx_dtype).bits return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" @@ -1043,20 +1026,27 @@ class ContextCache: shape: torch.Size, dtype: torch.dtype, *, - sparsity: Optional[SparsityMeta] = None, + val: Optional[torch.Tensor] = None, mutable: bool = False, ): """Return IrType for !torch.vtensor with the given shape and dtype""" stem = "torch.tensor" if mutable else "torch.vtensor" shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) - if sparsity is not None: - encoding = sparsity_encoding(shape, sparsity) - assert encoding is not None + if val is not None and val.layout in [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + ]: + # This is a sparse tensor. + encoding = sparsity_encoding(val) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", context=self._c, ) + # This is a dense tensor. 
return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c ) @@ -1065,21 +1055,17 @@ class ContextCache: try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") - sparsity = node.meta.get("sparsity", None) except KeyError as e: raise RuntimeError( f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" ) - return self.value_info_to_type( - val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable - ) + return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable) def value_info_to_type( self, val, *, tensor_meta: Optional[TensorMetadata] = None, - sparsity=None, mutable: bool = False, ): if tensor_meta is not None: @@ -1097,14 +1083,14 @@ class ContextCache: ) else: return self.tensor_metadata_to_type( - tensor_meta, sparsity=sparsity, mutable=mutable + tensor_meta, val=val, mutable=mutable ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, sparsity=sparsity, mutable=mutable + val.size(), val.dtype, val=val, mutable=mutable ) elif isinstance(val, list) and all( isinstance(x, TorchFakeTensor) for x in val @@ -1126,19 +1112,17 @@ class ContextCache: self, tm: TensorMetadata, *, - sparsity: Optional[SparsityMeta] = None, + val: Optional[torch.Tensor] = None, mutable: bool = False, ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype, sparsity, mutable) + key = (tm_shape, tm.dtype, val, mutable) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type( - tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable - ) + t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable) self._tensor_metadata_cache[key] = t return t diff --git a/test/python/fx_importer/sparsity/lit.local.cfg 
b/test/python/fx_importer/sparsity/lit.local.cfg new file mode 100644 index 000000000..274898b14 --- /dev/null +++ b/test/python/fx_importer/sparsity/lit.local.cfg @@ -0,0 +1,10 @@ +config.unsupported = True + +try: + import torch + if "2.5.0" <= str(torch.__version__): + print("Enabling sparsity propagation tests") + config.unsupported = False + +except ModuleNotFoundError: + ... diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparsity/sparse_test.py similarity index 82% rename from test/python/fx_importer/sparse_test.py rename to test/python/fx_importer/sparsity/sparse_test.py index 089a5eabb..56f9e9ec7 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparsity/sparse_test.py @@ -8,13 +8,11 @@ from typing import Any, Callable, Optional, Tuple, Dict import torch -import torch.export import torch.nn as nn import numpy as np from torch_mlir.extras.fx_decomp_util import get_decomposition_table from torch_mlir.extras.fx_importer import FxImporter -from torch_mlir.extras.fx_importer import SparsityMeta from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -23,139 +21,15 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( ) -# All sparse layouts currently supported in torch.sparse. -SPARSE_LAYOUTS = [ - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_csc, - torch.sparse_bsr, - torch.sparse_bsc, -] - - -def sparse_metadata(a: torch.Tensor) -> SparsityMeta: - """ - Returns a meta data tuple for the given sparse tensor. 
- - NOTE: this will be fully replaced by fx graph SparseTensorMetadata - """ - sparse_dim = a.sparse_dim() - dense_dim = a.dense_dim() - batch_dim = a.ndim - dense_dim - sparse_dim - blocksize = None - if a.layout is torch.sparse_coo: - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a._indices().dtype, - a._indices().dtype, - ) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: - if a.layout is torch.sparse_bsr: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.crow_indices().dtype, - a.col_indices().dtype, - ) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: - if a.layout is torch.sparse_bsc: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.ccol_indices().dtype, - a.row_indices().dtype, - ) - else: - raise RuntimeError(f"Unsupported sparse layout for {a}") - - -def sparse_export( - f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None -) -> torch.export.ExportedProgram: - """ - This is a ***temporary*** wrapper around `torch.export.export` - that eventually should be removed and simply replaced by the - standard API for exporting traced graphs. - - But until issue - - https://github.com/pytorch/pytorch/pull/117907 - - is addressed, this wrapper provides support for the sparse - tensor types by first converting all operands to dense tensors, - building the traced graph as for the dense case, then annotating - sparse parameters with their actual sparse layout attributes, - followed by some simple propagation rules. This temporary solution - accelerates testing torch-mlir with PyTorch sparse tensors until - the issue is resolved upstream. - """ - # Convert all arguments to dense. 
- dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) - mask = [a.layout in SPARSE_LAYOUTS for a in args] - # Build the regular FX traced graph with only dense arguments - # (the current version would crash otherwise, see issue above). - prog = torch.export.export(f, dargs, kwargs) - decomposition_table = get_decomposition_table() - if decomposition_table: - prog = prog.run_decompositions(decomposition_table) - # Annotate sparse arguments in the graph and apply some very - # basic propagation rules for sparsity. - specs = prog.graph_signature.input_specs - alen = len(specs) - k = 0 - for i, node in enumerate(prog.graph.nodes): - if node.op == "placeholder": - # Argument. - spec = specs[i] - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - if mask[k]: - node.meta["sparsity"] = sparse_metadata(args[k]) - k = k + 1 - elif node.op == "call_function": - opname = node.target._schema.name.split("::")[1] - # Zero preserving elt-wise unary op. - if opname in {"abs", "neg", "relu", "sin"}: - node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "_to_sparse" or opname == "to_sparse": - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 - ) - # TODO: Uncomment this to hack sparsity into the network. 
- # elif opname == "_to_dense" or opname == "to_dense": - # # hack (assumes we never really want the to_dense for now) - # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "select" and node.args[0].meta.get("sparsity", None): - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 - ) - elif opname == "stack" and node.args[0][0].meta.get("sparsity", None): - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64 - ) - return prog - - def export_and_import(f, *args, **kwargs): - """This method implements Stella's importer, stripped down to essentials.""" + """An FX graph importer, stripped down to essentials.""" context = ir.Context() torch_d.register_dialect(context) fx_importer = FxImporter(context=context) - prog = sparse_export(f, args, kwargs) + prog = torch.export.export(f, args, kwargs) + decomposition_table = get_decomposition_table() + if decomposition_table: + prog = prog.run_decompositions(decomposition_table) fx_importer.import_frozen_program(prog) return fx_importer.module @@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs): enable_ir_printing=False, ) # Compile with reference Linalg backend. - # TODO: runtime verification currently fails with 'rank mismatch' on - # memref.cast. Need to fix the IR first. + # TODO: runtime verification fails with 'rank mismatch' on memref.cast backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False) compiled = backend.compile(module) invoker = backend.load(compiled) @@ -218,7 +91,8 @@ def sparse_jit(f, *args, **kwargs): def run(f): - print(f"{f.__name__}") + # Print test name and torch version (for debugging). + print(f"{f.__name__} ({torch.__version__})") print("-" * len(f.__name__)) f() print()