mirror of https://github.com/llvm/torch-mlir
[torch-mlir] provide FX traced graph importer for sparse tensors (#2817)
Note that we are waiting for actual FX traced graph support for sparse tensors. For details see https://github.com/pytorch/pytorch/issues/117188 Until then, however, we provide this clever importer that builds the FX traced graph for the dense case and then puts a sparse annotation back on the parameters. With import test.rm_obsolete_build_automation
parent
1a7442e0aa
commit
105aad6f57
|
@ -207,10 +207,32 @@ SYMBOLIC_OP_TO_TORCH_OP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
"""Check whether an object in our graph is symbolic"""
|
def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
|
||||||
|
"""Returns sparse tensor encoding for the given sparse layout as string.
|
||||||
|
|
||||||
|
The method currently just supports 2-dim sparse formats. This should be
|
||||||
|
generalized to the torch.sparse encodings for prefix dense batch dimensions
|
||||||
|
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
|
||||||
|
is currently implemented in torch.sparse, this should not be a problem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: any rank
|
||||||
|
if len(shape) != 2:
|
||||||
|
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
|
||||||
|
|
||||||
|
if sparse_layout is torch.sparse_coo:
|
||||||
|
return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
|
||||||
|
if sparse_layout is torch.sparse_csr:
|
||||||
|
return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
|
||||||
|
if sparse_layout is torch.sparse_csc:
|
||||||
|
return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
|
||||||
|
# TODO: block format (derive block size!)
|
||||||
|
|
||||||
|
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
|
||||||
|
|
||||||
|
|
||||||
def is_symbolic(obj: Any) -> bool:
|
def is_symbolic(obj: Any) -> bool:
|
||||||
|
"""Check whether an object in our graph is symbolic"""
|
||||||
return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool))
|
return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool))
|
||||||
|
|
||||||
|
|
||||||
|
@ -455,17 +477,21 @@ class ContextCache:
|
||||||
|
|
||||||
"""Return IrType for !torch.vtensor with the given shape and dtype"""
|
"""Return IrType for !torch.vtensor with the given shape and dtype"""
|
||||||
|
|
||||||
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype):
|
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
|
||||||
shape_asm = self.format_asm_shape(shape)
|
shape_asm = self.format_asm_shape(shape)
|
||||||
mlir_dtype = str(self.dtype_to_type(dtype))
|
mlir_dtype = str(self.dtype_to_type(dtype))
|
||||||
|
if sparse_layout is not None:
|
||||||
|
sparsity = sparsity_encoding(shape, sparse_layout)
|
||||||
return IrType.parse(
|
return IrType.parse(
|
||||||
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
|
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
|
||||||
)
|
return IrType.parse(
|
||||||
|
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)
|
||||||
|
|
||||||
def node_val_to_type(self, node: torch_fx.Node) -> IrType:
|
def node_val_to_type(self, node: torch_fx.Node) -> IrType:
|
||||||
try:
|
try:
|
||||||
tensor_meta = node.meta.get("tensor_meta")
|
tensor_meta = node.meta.get("tensor_meta")
|
||||||
val = node.meta.get("val")
|
val = node.meta.get("val")
|
||||||
|
sparse_layout = node.meta.get("sparsity", None)
|
||||||
if tensor_meta is not None:
|
if tensor_meta is not None:
|
||||||
assert isinstance(tensor_meta, TensorMetadata)
|
assert isinstance(tensor_meta, TensorMetadata)
|
||||||
# Quantized tensor meta data is not preserved in our lowering,
|
# Quantized tensor meta data is not preserved in our lowering,
|
||||||
|
@ -475,12 +501,12 @@ class ContextCache:
|
||||||
f"Quantized tensor meta data is not supported."
|
f"Quantized tensor meta data is not supported."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.tensor_metadata_to_type(tensor_meta)
|
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
|
||||||
elif val is not None:
|
elif val is not None:
|
||||||
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
# some nodes with symbolic inputs pass a 'val' attribute rather than
|
||||||
# tensor_meta
|
# tensor_meta
|
||||||
if isinstance(val, TorchFakeTensor):
|
if isinstance(val, TorchFakeTensor):
|
||||||
return self.get_vtensor_type(val.size(), val.dtype)
|
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)
|
||||||
|
|
||||||
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
|
||||||
if t is not None:
|
if t is not None:
|
||||||
|
@ -495,15 +521,15 @@ class ContextCache:
|
||||||
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType:
|
def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
|
||||||
tm_shape = tuple(
|
tm_shape = tuple(
|
||||||
item.node if is_symbolic(item) else item for item in list(tm.shape)
|
item.node if is_symbolic(item) else item for item in list(tm.shape)
|
||||||
)
|
)
|
||||||
|
|
||||||
key = (tm_shape, tm.dtype)
|
key = (tm_shape, tm.dtype, sparse_layout)
|
||||||
t = self._tensor_metadata_cache.get(key)
|
t = self._tensor_metadata_cache.get(key)
|
||||||
if t is None:
|
if t is None:
|
||||||
t = self.get_vtensor_type(tm.shape, tm.dtype)
|
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
|
||||||
self._tensor_metadata_cache[key] = t
|
self._tensor_metadata_cache[key] = t
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.export
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from torch_mlir.extras.fx_importer import FxImporter
|
||||||
|
from torch_mlir import ir
|
||||||
|
from torch_mlir.dialects import torch as torch_d
|
||||||
|
|
||||||
|
|
||||||
|
# All sparse layouts currently supported in torch.sparse.
|
||||||
|
SPARSE_LAYOUTS = [
|
||||||
|
torch.sparse_coo,
|
||||||
|
torch.sparse_csr,
|
||||||
|
torch.sparse_csc,
|
||||||
|
torch.sparse_bsr,
|
||||||
|
torch.sparse_bsc
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_export(f: Callable,
|
||||||
|
args: Tuple[Any, ...],
|
||||||
|
kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
|
||||||
|
"""
|
||||||
|
This is a ***temporary*** wrapper around `torch.export.export`
|
||||||
|
that eventually should be removed and simply replaced by the
|
||||||
|
standard API for exporting traced graphs.
|
||||||
|
|
||||||
|
But until issue
|
||||||
|
|
||||||
|
https://github.com/pytorch/pytorch/pull/117907
|
||||||
|
|
||||||
|
is addressed, this wrapper provides support for the sparse
|
||||||
|
tensor types by first converting all operands to dense tensors,
|
||||||
|
building the traced graph as for the dense case, and then
|
||||||
|
annotating sparse parameters with their actual sparse layout
|
||||||
|
attributes. This temporary solution accelerates testing
|
||||||
|
torch-mlir with PyTorch sparse tensors until the issue is
|
||||||
|
resolved.
|
||||||
|
"""
|
||||||
|
# Convert all arguments to dense.
|
||||||
|
dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args )
|
||||||
|
mask = [ a.layout in SPARSE_LAYOUTS for a in args ]
|
||||||
|
# Build the regular FX traced graph with only dense arguments
|
||||||
|
# (the current version would crash otherwise, see issue above).
|
||||||
|
prog = torch.export.export(f, dargs, kwargs, constraints=None)
|
||||||
|
# Annotate sparse arguments in the graph.
|
||||||
|
alen = len(args)
|
||||||
|
for i, node in enumerate(prog.graph.nodes):
|
||||||
|
if node.op == "placeholder" and i < alen and mask[i]:
|
||||||
|
node.meta['sparsity'] = args[i].layout
|
||||||
|
# TODO: annotate inputs to change calling conventions!
|
||||||
|
return prog
|
||||||
|
|
||||||
|
|
||||||
|
def export_and_import(f, *args, **kwargs):
|
||||||
|
"""This method implements Stella's importer, stripped down to essentials."""
|
||||||
|
context = ir.Context()
|
||||||
|
torch_d.register_dialect(context)
|
||||||
|
fx_importer = FxImporter(context=context)
|
||||||
|
prog = sparse_export(f, args, kwargs)
|
||||||
|
fx_importer.import_frozen_exported_program(prog)
|
||||||
|
return fx_importer.module_op
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
print(f"{f.__name__}")
|
||||||
|
print("-" * len(f.__name__))
|
||||||
|
f()
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_sparse_sum
|
||||||
|
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
|
||||||
|
# CHECK: func.func @main(
|
||||||
|
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
|
||||||
|
# CHECK: %[[N:.*]] = torch.constant.none
|
||||||
|
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
|
||||||
|
# CHECK: }
|
||||||
|
def test_sparse_sum():
|
||||||
|
|
||||||
|
class SumNet(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SumNet, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.sum()
|
||||||
|
|
||||||
|
|
||||||
|
dense_input = torch.ones(64, 64)
|
||||||
|
sparse_input = dense_input.to_sparse_csr()
|
||||||
|
m = export_and_import(SumNet(), sparse_input)
|
||||||
|
print(m)
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_sparse_SpMM
|
||||||
|
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
|
||||||
|
# CHECK: func.func @main(
|
||||||
|
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
|
||||||
|
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
|
||||||
|
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
|
||||||
|
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
|
||||||
|
# CHECK: }
|
||||||
|
def test_sparse_SpMM():
|
||||||
|
|
||||||
|
class MatMulNet(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MatMulNet, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.matmul(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
dense_input = torch.ones(64, 64)
|
||||||
|
sparse_input = dense_input.to_sparse_coo()
|
||||||
|
m = export_and_import(MatMulNet(), sparse_input, dense_input)
|
||||||
|
print(m)
|
Loading…
Reference in New Issue