[torch-mlir] provide FX traced graph importer for sparse tensors (#2817)

Note that we are waiting for actual FX traced graph support for sparse
tensors. For details see

https://github.com/pytorch/pytorch/issues/117188

Until then, however, we provide this clever importer that builds the FX
traced graph for for the dense case and then puts a sparse annotation
back on the parameters.

With import test.
rm_obsolete_build_automation
Aart Bik 2024-01-30 21:22:12 -08:00 committed by GitHub
parent 1a7442e0aa
commit 105aad6f57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 166 additions and 10 deletions

View File

@ -207,10 +207,32 @@ SYMBOLIC_OP_TO_TORCH_OP = {
}
"""Check whether an object in our graph is symbolic"""
def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
"""Returns sparse tensor encoding for the given sparse layout as string.
The method currently just supports 2-dim sparse formats. This should be
generalized to the torch.sparse encodings for prefix dense batch dimensions
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
is currently implememented in torch.sparse, this should not a be problem.
"""
# TODO: any rank
if len(shape) != 2:
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
if sparse_layout is torch.sparse_coo:
return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
if sparse_layout is torch.sparse_csr:
return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
if sparse_layout is torch.sparse_csc:
return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
# TODO: block format (derive block size!)
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
def is_symbolic(obj: Any) -> bool:
"""Check whether an object in our graph is symbolic"""
return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool))
@ -337,7 +359,7 @@ class FxImporter:
) from e
arg_replacements[input_name] = state_value
# Remove any lifted placeholders, replacing their uses with the state
# Remove any lifted placeholders, replacing their uses with the state
# replacement value.
g = prog.graph
for node in g.nodes:
@ -455,17 +477,21 @@ class ContextCache:
"""Return IrType for !torch.vtensor with the given shape and dtype"""
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype):
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
shape_asm = self.format_asm_shape(shape)
mlir_dtype = str(self.dtype_to_type(dtype))
if sparse_layout is not None:
sparsity = sparsity_encoding(shape, sparse_layout)
return IrType.parse(
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
return IrType.parse(
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
)
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)
def node_val_to_type(self, node: torch_fx.Node) -> IrType:
try:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparse_layout = node.meta.get("sparsity", None)
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
# Quantized tensor meta data is not preserved in our lowering,
@ -475,12 +501,12 @@ class ContextCache:
f"Quantized tensor meta data is not supported."
)
else:
return self.tensor_metadata_to_type(tensor_meta)
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
if isinstance(val, TorchFakeTensor):
return self.get_vtensor_type(val.size(), val.dtype)
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
if t is not None:
@ -495,15 +521,15 @@ class ContextCache:
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
)
def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType:
def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
tm_shape = tuple(
item.node if is_symbolic(item) else item for item in list(tm.shape)
)
key = (tm_shape, tm.dtype)
key = (tm_shape, tm.dtype, sparse_layout)
t = self._tensor_metadata_cache.get(key)
if t is None:
t = self.get_vtensor_type(tm.shape, tm.dtype)
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
self._tensor_metadata_cache[key] = t
return t

View File

@ -0,0 +1,130 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import torch.export
import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
# All sparse layouts currently supported in torch.sparse.
SPARSE_LAYOUTS = [
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc
]
def sparse_export(f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
"""
This is a ***temporary*** wrapper around `torch.export.export`
that eventually should be removed and simply replaced by the
standard API for exporting traced graphs.
But until issue
https://github.com/pytorch/pytorch/pull/117907
is addressed, this wrapper provides support for the sparse
tensor types by first converting all operands to dense tensors,
building the traced graph as for the dense case, and then
annotation sparse parameters with their actual sparse layout
attributes. This temporary solution accelerates testing
torch-mlir with PyTorch sparse tensors until the issue is
resovled.
"""
# Convert all arguments to dense.
dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args )
mask = [ a.layout in SPARSE_LAYOUTS for a in args ]
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs, constraints=None)
# Annotate sparse arguments in the graph.
alen = len(args)
for i, node in enumerate(prog.graph.nodes):
if node.op == "placeholder" and i < alen and mask[i]:
node.meta['sparsity'] = args[i].layout
# TODO: annotate inputs to change calling conventions!
return prog
def export_and_import(f, *args, **kwargs):
"""This method implements Stella's importer, stripped down to essentials."""
context = ir.Context()
torch_d.register_dialect(context)
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
fx_importer.import_frozen_exported_program(prog)
return fx_importer.module_op
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK: %[[N:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
# CHECK: }
def test_sparse_sum():
class SumNet(torch.nn.Module):
def __init__(self):
super(SumNet, self).__init__()
def forward(self, x):
return x.sum()
dense_input = torch.ones(64, 64)
sparse_input = dense_input.to_sparse_csr()
m = export_and_import(SumNet(), sparse_input)
print(m)
@run
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
# CHECK: }
def test_sparse_SpMM():
class MatMulNet(torch.nn.Module):
def __init__(self):
super(MatMulNet, self).__init__()
def forward(self, x, y):
return torch.matmul(x, y)
dense_input = torch.ones(64, 64)
sparse_input = dense_input.to_sparse_coo()
m = export_and_import(MatMulNet(), sparse_input, dense_input)
print(m)