# mirror of https://github.com/llvm/torch-mlir
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
# Also available under a BSD-style license. See LICENSE.
|
||
|
|
||
|
# RUN: %PYTHON %s | FileCheck %s
|
||
|
|
||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.export
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from torch_mlir.extras.fx_importer import FxImporter
|
||
|
from torch_mlir import ir
|
||
|
from torch_mlir.dialects import torch as torch_d
|
||
|
|
||
|
|
||
|
# Every sparse tensor layout that torch.sparse currently exposes; used to
# decide which arguments need dense conversion / sparsity annotation below.
SPARSE_LAYOUTS = [
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
]
def sparse_export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
) -> torch.export.ExportedProgram:
    """
    This is a ***temporary*** wrapper around `torch.export.export`
    that eventually should be removed and simply replaced by the
    standard API for exporting traced graphs.

    But until issue

      https://github.com/pytorch/pytorch/pull/117907

    is addressed, this wrapper provides support for the sparse
    tensor types by first converting all operands to dense tensors,
    building the traced graph as for the dense case, and then
    annotating sparse parameters with their actual sparse layout
    attributes. This temporary solution accelerates testing
    torch-mlir with PyTorch sparse tensors until the issue is
    resolved.
    """
    # Convert all arguments to dense so the export machinery never sees a
    # sparse tensor (the current version would crash otherwise, see issue
    # above). `mask[i]` remembers which positions were originally sparse.
    # NOTE(review): assumes every element of `args` is a tensor with a
    # `.layout` attribute — confirm against callers.
    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
    mask = [a.layout in SPARSE_LAYOUTS for a in args]
    # Build the regular FX traced graph with only dense arguments.
    # `constraints=None` was dropped here: the parameter was deprecated and
    # later removed from `torch.export.export`, and omitting it preserves
    # the old default behavior while keeping newer torch versions working.
    prog = torch.export.export(f, dargs, kwargs)
    # Annotate sparse arguments in the graph. This relies on the first
    # placeholders appearing in the same order as `args`.
    alen = len(args)
    for i, node in enumerate(prog.graph.nodes):
        if node.op == "placeholder" and i < alen and mask[i]:
            node.meta['sparsity'] = args[i].layout
    # TODO: annotate inputs to change calling conventions!
    return prog
def export_and_import(f, *args, **kwargs):
    """This method implements Stella's importer, stripped down to essentials."""
    ctx = ir.Context()
    torch_d.register_dialect(ctx)
    importer = FxImporter(context=ctx)
    exported = sparse_export(f, args, kwargs)
    importer.import_frozen_exported_program(exported)
    return importer.module_op
def run(f):
    """Execute test callable *f*, framed for FileCheck.

    Emits the function name, an underline of matching length, whatever
    *f* itself prints, and a trailing blank line. Returns None, so when
    used as a bare @run decorator the decorated name is consumed.
    """
    title = f.__name__
    print(title)
    print("-" * len(title))
    f()
    print()
@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK: %[[N:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
# CHECK: }
def test_sparse_sum():
    # Reduce a CSR matrix to a scalar; the importer must carry the CSR
    # encoding through to the MLIR argument type.
    class SumNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.sum()

    sparse_input = torch.ones(64, 64).to_sparse_csr()
    module = export_and_import(SumNet(), sparse_input)
    print(module)
@run
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
# CHECK: }
def test_sparse_SpMM():
    # Sparse-times-dense matmul: only the first operand carries a COO
    # encoding; the result type stays dense.
    class MatMulNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.matmul(x, y)

    dense = torch.ones(64, 64)
    module = export_and_import(MatMulNet(), dense.to_sparse_coo(), dense)
    print(module)