mirror of https://github.com/llvm/torch-mlir
[torch-mlir] provide FX traced graph importer for sparse tensors (#2817)
Note that we are waiting for actual FX traced graph support for sparse tensors. For details see https://github.com/pytorch/pytorch/issues/117188 Until then, however, we provide this clever importer that builds the FX traced graph for for the dense case and then puts a sparse annotation back on the parameters. With import test.rm_obsolete_build_automation
parent
1a7442e0aa
commit
105aad6f57
|
@ -207,10 +207,32 @@ SYMBOLIC_OP_TO_TORCH_OP = {
|
|||
}
|
||||
|
||||
|
||||
"""Check whether an object in our graph is symbolic"""
|
||||
def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
|
||||
"""Returns sparse tensor encoding for the given sparse layout as string.
|
||||
|
||||
The method currently just supports 2-dim sparse formats. This should be
|
||||
generalized to the torch.sparse encodings for prefix dense batch dimensions
|
||||
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
|
||||
is currently implememented in torch.sparse, this should not a be problem.
|
||||
"""
|
||||
|
||||
# TODO: any rank
|
||||
if len(shape) != 2:
|
||||
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
|
||||
|
||||
if sparse_layout is torch.sparse_coo:
|
||||
return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
|
||||
if sparse_layout is torch.sparse_csr:
|
||||
return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
|
||||
if sparse_layout is torch.sparse_csc:
|
||||
return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
|
||||
# TODO: block format (derive block size!)
|
||||
|
||||
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
|
||||
|
||||
|
||||
def is_symbolic(obj: Any) -> bool:
|
||||
"""Check whether an object in our graph is symbolic"""
|
||||
return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool))
|
||||
|
||||
|
||||
|
@ -337,7 +359,7 @@ class FxImporter:
|
|||
) from e
|
||||
arg_replacements[input_name] = state_value
|
||||
|
||||
# Remove any lifted placeholders, replacing their uses with the state
|
||||
# Remove any lifted placeholders, replacing their uses with the state
|
||||
# replacement value.
|
||||
g = prog.graph
|
||||
for node in g.nodes:
|
||||
|
@ -455,17 +477,21 @@ class ContextCache:
|
|||
|
||||
"""Return IrType for !torch.vtensor with the given shape and dtype"""
|
||||
|
||||
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype):
|
||||
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
|
||||
shape_asm = self.format_asm_shape(shape)
|
||||
mlir_dtype = str(self.dtype_to_type(dtype))
|
||||
if sparse_layout is not None:
|
||||
sparsity = sparsity_encoding(shape, sparse_layout)
|
||||
return IrType.parse(
|
||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
|
||||
return IrType.parse(
|
||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
|
||||
)
|
||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)
|
||||
|
||||
def node_val_to_type(self, node: torch_fx.Node) -> IrType:
|
||||
try:
|
||||
tensor_meta = node.meta.get("tensor_meta")
|
||||
val = node.meta.get("val")
|
||||
sparse_layout = node.meta.get("sparsity", None)
|
||||
if tensor_meta is not None:
|
||||
assert isinstance(tensor_meta, TensorMetadata)
|
||||
# Quantized tensor meta data is not preserved in our lowering,
|
||||
|
@ -475,12 +501,12 @@ class ContextCache:
|
|||
f"Quantized tensor meta data is not supported."
|
||||
)
|
||||
else:
|
||||
return self.tensor_metadata_to_type(tensor_meta)
|
||||
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
|
||||
elif val is not None:
|
||||
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
||||
# tensor_meta
|
||||
if isinstance(val, TorchFakeTensor):
|
||||
return self.get_vtensor_type(val.size(), val.dtype)
|
||||
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)
|
||||
|
||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||
if t is not None:
|
||||
|
@ -495,15 +521,15 @@ class ContextCache:
|
|||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||
)
|
||||
|
||||
def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType:
|
||||
def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
|
||||
tm_shape = tuple(
|
||||
item.node if is_symbolic(item) else item for item in list(tm.shape)
|
||||
)
|
||||
|
||||
key = (tm_shape, tm.dtype)
|
||||
key = (tm_shape, tm.dtype, sparse_layout)
|
||||
t = self._tensor_metadata_cache.get(key)
|
||||
if t is None:
|
||||
t = self.get_vtensor_type(tm.shape, tm.dtype)
|
||||
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
|
||||
self._tensor_metadata_cache[key] = t
|
||||
return t
|
||||
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir.extras.fx_importer import FxImporter
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
|
||||
|
||||
# All sparse layouts currently supported in torch.sparse.
|
||||
SPARSE_LAYOUTS = [
|
||||
torch.sparse_coo,
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc
|
||||
]
|
||||
|
||||
|
||||
def sparse_export(f: Callable,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
|
||||
"""
|
||||
This is a ***temporary*** wrapper around `torch.export.export`
|
||||
that eventually should be removed and simply replaced by the
|
||||
standard API for exporting traced graphs.
|
||||
|
||||
But until issue
|
||||
|
||||
https://github.com/pytorch/pytorch/pull/117907
|
||||
|
||||
is addressed, this wrapper provides support for the sparse
|
||||
tensor types by first converting all operands to dense tensors,
|
||||
building the traced graph as for the dense case, and then
|
||||
annotation sparse parameters with their actual sparse layout
|
||||
attributes. This temporary solution accelerates testing
|
||||
torch-mlir with PyTorch sparse tensors until the issue is
|
||||
resovled.
|
||||
"""
|
||||
# Convert all arguments to dense.
|
||||
dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args )
|
||||
mask = [ a.layout in SPARSE_LAYOUTS for a in args ]
|
||||
# Build the regular FX traced graph with only dense arguments
|
||||
# (the current version would crash otherwise, see issue above).
|
||||
prog = torch.export.export(f, dargs, kwargs, constraints=None)
|
||||
# Annotate sparse arguments in the graph.
|
||||
alen = len(args)
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
if node.op == "placeholder" and i < alen and mask[i]:
|
||||
node.meta['sparsity'] = args[i].layout
|
||||
# TODO: annotate inputs to change calling conventions!
|
||||
return prog
|
||||
|
||||
|
||||
def export_and_import(f, *args, **kwargs):
|
||||
"""This method implements Stella's importer, stripped down to essentials."""
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
fx_importer = FxImporter(context=context)
|
||||
prog = sparse_export(f, args, kwargs)
|
||||
fx_importer.import_frozen_exported_program(prog)
|
||||
return fx_importer.module_op
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_sum
|
||||
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
|
||||
# CHECK: %[[N:.*]] = torch.constant.none
|
||||
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
|
||||
# CHECK: }
|
||||
def test_sparse_sum():
|
||||
|
||||
class SumNet(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SumNet, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.sum()
|
||||
|
||||
|
||||
dense_input = torch.ones(64, 64)
|
||||
sparse_input = dense_input.to_sparse_csr()
|
||||
m = export_and_import(SumNet(), sparse_input)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sparse_SpMM
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
|
||||
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
|
||||
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
|
||||
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
|
||||
# CHECK: }
|
||||
def test_sparse_SpMM():
|
||||
|
||||
class MatMulNet(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MatMulNet, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.matmul(x, y)
|
||||
|
||||
|
||||
dense_input = torch.ones(64, 64)
|
||||
sparse_input = dense_input.to_sparse_coo()
|
||||
m = export_and_import(MatMulNet(), sparse_input, dense_input)
|
||||
print(m)
|
Loading…
Reference in New Issue