2024-01-31 13:22:12 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
|
2024-03-15 23:29:48 +08:00
|
|
|
from typing import Any, Callable, Optional, Tuple, Dict
|
2024-01-31 13:22:12 +08:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.export
|
|
|
|
import torch.nn as nn
|
2024-04-13 00:56:32 +08:00
|
|
|
import numpy as np
|
2024-01-31 13:22:12 +08:00
|
|
|
|
|
|
|
from torch_mlir.extras.fx_importer import FxImporter
|
2024-02-13 08:10:57 +08:00
|
|
|
from torch_mlir.extras.fx_importer import SparsityMeta
|
2024-01-31 13:22:12 +08:00
|
|
|
from torch_mlir import ir
|
|
|
|
from torch_mlir.dialects import torch as torch_d
|
2024-02-13 02:04:54 +08:00
|
|
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
|
|
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
|
|
|
RefBackendLinalgOnTensorsBackend,
|
|
|
|
)
|
2024-01-31 13:22:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Every sparse layout torch.sparse currently offers: coordinate (COO),
# compressed row/column (CSR/CSC) and their blocked variants (BSR/BSC).
SPARSE_LAYOUTS = [torch.sparse_coo]
SPARSE_LAYOUTS += [torch.sparse_csr, torch.sparse_csc]
SPARSE_LAYOUTS += [torch.sparse_bsr, torch.sparse_bsc]
|
|
|
|
|
|
|
|
|
2024-02-13 08:10:57 +08:00
|
|
|
def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
    """
    Describe the sparsity of ``a`` as a SparsityMeta record.

    The record is built from, in order: the layout, the number of batch,
    sparse and dense dimensions, the block size (only for BSR/BSC, None
    otherwise), and the dtypes of the position and coordinate index tensors.
    The batch count is whatever remains of ``a.ndim`` once the sparse and
    dense dimensions are taken away.

    NOTE: this will be fully replaced by fx graph SparseTensorMetadata

    Raises:
        RuntimeError: for any layout other than COO, CSR, CSC, BSR or BSC.
    """
    n_sparse = a.sparse_dim()
    n_dense = a.dense_dim()
    n_batch = a.ndim - n_dense - n_sparse
    layout = a.layout

    if layout is torch.sparse_coo:
        # A single index tensor supplies both the position and coordinate dtype.
        index_dtype = a._indices().dtype
        return SparsityMeta(
            layout, n_batch, n_sparse, n_dense, None, index_dtype, index_dtype
        )

    if layout is torch.sparse_csr or layout is torch.sparse_bsr:
        compressed, plain = a.crow_indices(), a.col_indices()
    elif layout is torch.sparse_csc or layout is torch.sparse_bsc:
        compressed, plain = a.ccol_indices(), a.row_indices()
    else:
        raise RuntimeError(f"Unsupported sparse layout for {a}")

    # Blocked layouts: the block shape is read from values().shape at the two
    # positions after the batch dimensions plus one further leading dimension.
    blocksize = None
    if layout is torch.sparse_bsr or layout is torch.sparse_bsc:
        blocksize = a.values().shape[n_batch + 1 : n_batch + 3]

    return SparsityMeta(
        layout, n_batch, n_sparse, n_dense, blocksize, compressed.dtype, plain.dtype
    )
|
|
|
|
|
|
|
|
|
|
|
|
def sparse_export(
    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> torch.export.ExportedProgram:
    """
    This is a ***temporary*** wrapper around `torch.export.export`
    that eventually should be removed and simply replaced by the
    standard API for exporting traced graphs.

    But until issue

      https://github.com/pytorch/pytorch/pull/117907

    is addressed, this wrapper provides support for the sparse
    tensor types by first converting all operands to dense tensors,
    building the traced graph as for the dense case, then annotating
    sparse parameters with their actual sparse layout attributes,
    followed by some simple propagation rules. This temporary solution
    accelerates testing torch-mlir with PyTorch sparse tensors until
    the issue is resolved upstream.

    Args:
        f: the callable (typically an ``nn.Module``) to export.
        args: positional arguments. Sparse tensors among them are densified
            for tracing, and their sparsity is re-attached to the matching
            placeholder nodes afterwards.
        kwargs: optional keyword arguments, forwarded unchanged.

    Returns:
        The exported program. ``node.meta["sparsity"]`` is set on sparse
        placeholders and on nodes reached by the propagation rules below.
    """

    def is_sparse(a: Any) -> bool:
        # Non-tensor arguments (e.g. Python scalars) have no layout at all.
        return isinstance(a, torch.Tensor) and a.layout in SPARSE_LAYOUTS

    # Convert all arguments to dense.
    mask = [is_sparse(a) for a in args]
    dargs = tuple(a.to_dense() if sparse else a for a, sparse in zip(args, mask))
    # Build the regular FX traced graph with only dense arguments
    # (the current version would crash otherwise, see issue above).
    prog = torch.export.export(f, dargs, kwargs)
    # Annotate sparse arguments in the graph and apply some very
    # basic propagation rules for sparsity.
    specs = prog.graph_signature.input_specs
    k = 0  # position in `args` of the next user input
    for i, node in enumerate(prog.graph.nodes):
        if node.op == "placeholder":
            # Argument. Assumes placeholders precede every other node, so that
            # `i` also indexes the matching input spec.
            spec = specs[i]
            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
                # Guard against more user inputs than positional args
                # (e.g. inputs coming from kwargs, which are never annotated).
                if k < len(mask) and mask[k]:
                    node.meta["sparsity"] = sparse_metadata(args[k])
                k = k + 1
        elif node.op == "call_function":
            # Targets such as operator.getitem or higher-order ops have no
            # ATen schema; none of the sparsity rules below apply to them.
            schema = getattr(node.target, "_schema", None)
            if schema is None:
                continue
            # TODO: use upstream _opname implementation when available
            opname = schema.name.split("::")[1]
            # Zero preserving elt-wise unary op.
            if opname in {"abs", "neg", "relu", "sin"}:
                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
            elif opname == "_to_sparse":
                dim = len(node.meta.get("val").shape)
                node.meta["sparsity"] = SparsityMeta(
                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                )
            # TODO: Uncomment this to hack sparsity into the network.
            # elif opname == "_to_dense":
            #     # hack (assumes we never really want the to_dense for now)
            #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
            elif opname == "select" and node.args[0].meta.get("sparsity", None):
                dim = len(node.meta.get("val").shape)
                node.meta["sparsity"] = SparsityMeta(
                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                )
            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
                # The stacked result gets one dense trailing dimension.
                dim = len(node.meta.get("val").shape)
                node.meta["sparsity"] = SparsityMeta(
                    torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
                )
    return prog
|
|
|
|
|
|
|
|
|
|
|
|
def export_and_import(f, *args, **kwargs):
    """Export ``f`` with sparsity annotations and import it into torch-MLIR.

    This is a stripped-down version of Stella's importer. A fresh MLIR context
    with the torch dialect registered is created, and the program produced by
    ``sparse_export`` is imported as a frozen program. The resulting
    ``FxImporter.module`` is returned.
    """
    ctx = ir.Context()
    torch_d.register_dialect(ctx)
    importer = FxImporter(context=ctx)
    importer.import_frozen_program(sparse_export(f, args, kwargs))
    return importer.module
|
|
|
|
|
|
|
|
|
|
|
|
def sparse_jit(f, *args, **kwargs):
    """Compile ``f`` with the Linalg reference backend and run it on ``args``.

    The callable is exported and imported (with sparse annotations), lowered
    to Linalg-on-tensors IR, compiled by ``RefBackendLinalgOnTensorsBackend``
    and invoked. Non-scalar buffers of ``f`` (assumed dense) are passed first,
    followed by the flattened inputs. The return value is whatever the
    compiled ``main`` yields (numpy arrays).
    """
    # Import module and lower into Linalg IR.
    module = export_and_import(f, *args, **kwargs)
    lowering_pipeline = (
        "builtin.module("
        "func.func(torch-decompose-complex-ops),"
        "torch-backend-to-linalg-on-tensors-backend-pipeline)"
    )
    run_pipeline_with_repro_report(
        module,
        lowering_pipeline,
        "Lowering TorchFX IR -> Linalg IR",
        enable_ir_printing=False,
    )

    # Compile with reference Linalg backend.
    backend = RefBackendLinalgOnTensorsBackend()
    invoker = backend.load(backend.compile(module))

    # Buffer parameters go first (assumed dense); 0-dim buffers are skipped.
    # TODO: filters out scalar arguments, anything else?
    buffers = dict(f.named_buffers(remove_duplicate=True))
    flat_buffers, _ = torch.utils._pytree.tree_flatten(buffers)
    xargs = [buf.numpy() for buf in flat_buffers if len(buf.shape) > 0]

    # Then the inputs. Sparse input tensors are split into their composite
    # tensors, and every PyTorch tensor is handed over as its backing numpy
    # array. The output consists of numpy arrays as well, which can trivially
    # be reconstructed into PyTorch tensors (dense and sparse).
    for a in args:
        layout = a.layout
        if layout is torch.sparse_coo:
            # MLIR needs an extra position array holding [0, nnz]; the COO
            # format always uses int64 indices.
            xargs.append(np.array([0, a._nnz()], dtype=np.int64))
            # Split the ndim x nnz index tensor into ndim arrays of length nnz,
            # matching the MLIR SoA COO representation.
            xargs.extend(row.numpy() for row in a._indices())
            xargs.append(a._values().numpy())
        elif layout is torch.sparse_csr or layout is torch.sparse_bsr:
            xargs += [
                a.crow_indices().numpy(),
                a.col_indices().numpy(),
                a.values().numpy(),
            ]
        elif layout is torch.sparse_csc or layout is torch.sparse_bsc:
            xargs += [
                a.ccol_indices().numpy(),
                a.row_indices().numpy(),
                a.values().numpy(),
            ]
        else:
            xargs.append(a.numpy())

    # Invoke.
    return invoker.main(*xargs)
|
2024-01-31 13:22:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
def run(f):
    """Print a header naming ``f``, call it, then print an empty line.

    Used as a decorator on the tests below. Because it returns None, each
    decorated name is rebound to None once its test has executed.
    """
    name = f.__name__
    print(name)
    print("-" * len(name))
    f()
    print()
|
|
|
|
|
|
|
|
|
2024-04-10 02:21:30 +08:00
|
|
|
@run
#
# CHECK-LABEL: test_sparse_id
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> {
# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK: [ 0, 1, 10, 19]{{\]}}),
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: [0 4]
# CHECK: [0 1 2 9]
# CHECK: [ 0 1 10 19]
# CHECK: [-1000. -1. 1. 1000.]
#
def test_sparse_id():
    """Identity network on a 10x20 COO matrix; the encoding passes straight through."""

    class IdNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    net = IdNet()
    coords = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
    vals = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
    sparse_input = torch.sparse_coo_tensor(coords, vals, size=[10, 20])
    print(export_and_import(net, sparse_input))

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    # Positions, the two coordinate arrays, then the values.
    for i in range(4):
        print(res2[i])
|
2024-04-10 02:21:30 +08:00
|
|
|
|
|
|
|
|
2024-01-31 13:22:12 +08:00
|
|
|
@run
#
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK: %[[N:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
# CHECK: }
#
# CHECK: torch.sparse = tensor(4096.)
# CHECK: torch.mlir = 4096.0
#
def test_sparse_sum():
    """Full reduction of an all-ones 64x64 CSR matrix (expected sum: 4096)."""

    class SumNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.sum()

    net = SumNet()
    sparse_input = torch.ones(64, 64).to_sparse_csr()
    print(export_and_import(net, sparse_input))

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse =", res1)
    print("torch.mlir =", res2)
|
|
|
|
|
2024-01-31 13:22:12 +08:00
|
|
|
|
2024-02-24 03:57:20 +08:00
|
|
|
@run
#
# CHECK-LABEL: test_sparse_SpMV
# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> {
# CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
# CHECK: return %[[R]] : !torch.vtensor<[10],f32>
# CHECK: }
#
# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.])
# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.]
#
def test_sparse_SpMV():
    """All-ones 10x10 BSR matrix (2x2 blocks) times the vector 1..10."""

    class SpMVNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, v):
            return torch.mv(x, v)

    net = SpMVNet()
    dense_vector = torch.arange(1, 11, dtype=torch.float32)
    sparse_input = torch.ones(10, 10).to_sparse_bsr(blocksize=(2, 2))
    operands = (sparse_input, dense_vector)
    print(export_and_import(net, *operands))

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(*operands)
    res2 = sparse_jit(net, *operands)
    print("torch.sparse =", res1)
    print("torch.mlir =", res2)
|
|
|
|
|
2024-02-24 03:57:20 +08:00
|
|
|
|
2024-01-31 13:22:12 +08:00
|
|
|
@run
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
#
def test_sparse_SpMM():
    """All-ones 8x8 COO matrix times its dense counterpart (every entry is 8)."""

    class MatMulNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.matmul(x, y)

    net = MatMulNet()
    dense_input = torch.ones(8, 8)
    sparse_input = dense_input.to_sparse_coo()
    # The imported module is built but intentionally not printed here.
    export_and_import(net, sparse_input, dense_input)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input, dense_input)
    res2 = sparse_jit(net, sparse_input, dense_input)
    for label, result in (("torch.sparse", res1), ("torch.mlir", res2)):
        print(label)
        print(result)
|
2024-02-13 08:10:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: }
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# CHECK: [ -3., -4.],
# CHECK: [ -5., -6.],
# CHECK: [ -7., -8.],
# CHECK: [ -9., -10.],
# CHECK: [-11., -12.],
# CHECK: [-13., -14.],
# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
# CHECK: layout=torch.sparse_csr)
# CHECK: torch.mlir
# CHECK: [0 2 4 6 8]
# CHECK: [0 1 0 1 0 1 0 1]
# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
# CHECK: -15. -16.]
# CHECK: torch.mlir.batch
#
def test_sparse_eltwise():
    """Negate a 4x2x2 tensor stored as CSR with a dense trailing dimension
    and as batched CSR; only the former is also run through sparse_jit."""

    class EltNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return -x

    net = EltNet()
    dense_input = torch.arange(1, 17, dtype=torch.float32).reshape(4, 2, 2)

    # This yields a plain CSR with dense **sub**tensor
    sparse_input = dense_input.to_sparse_csr(dense_dim=1)
    # This yields a **batched** CSR.
    batch_input = dense_input.to_sparse_csr(dense_dim=0)
    for inp in (sparse_input, batch_input):
        print(export_and_import(net, inp))

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    # TODO: make this work
    # res3 = sparse_jit(net, batch_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    for i in range(3):
        print(res2[i])
    print("torch.mlir.batch")
|
2024-04-09 07:46:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK-LABEL: test_sparse_coo3
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#[[$COO3]]>) -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> {
# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 1, 4, 9, 9],
# CHECK: [ 0, 1, 1, 5, 19, 19],
# CHECK: [ 0, 1, 3, 6, 28, 29]{{\]}}),
# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]),
# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK: [0 6]
# CHECK: [0 1 1 4 9 9]
# CHECK: [ 0 1 1 5 19 19]
# CHECK: [ 0 1 3 6 28 29]
# CHECK: [ 0. 0. 1. 2. 3. 1000.]
#
def test_sparse_coo3():
    """ReLU over a directly constructed 10x20x30 COO tensor."""

    class COO3Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    net = COO3Net()

    # Direct 3-dim COO construction.
    coords = torch.tensor(
        [[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]]
    )
    vals = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64)
    sparse_input = torch.sparse_coo_tensor(coords, vals, size=[10, 20, 30])
    print(export_and_import(net, sparse_input))

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    # Positions, three coordinate arrays, then the values.
    for i in range(5):
        print(res2[i])
|
2024-05-09 10:01:24 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
# CHECK: [0, 0, 1, 1, 0, 0, 1, 1],
# CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
# CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
# CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
#
def test_sparse_activation():
    """A network whose output is produced by ``to_sparse()`` on a dense input."""

    class SparseActivationCOO(torch.nn.Module):
        def forward(self, x):
            return x.to_sparse()

    net = SparseActivationCOO()
    x = torch.ones(2, 2, 2)
    # The imported module is built but intentionally not printed here.
    export_and_import(net, x)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    # res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    # print("torch.mlir")
    # for i in range(5):
    #     print(res2[i])
|
2024-05-09 12:18:42 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK-LABEL: test_sparse_network
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
# ... lots of IR ...
# CHECK-COUNT-15: torch.aten.mul.Tensor
# ... lots of IR ...
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
#
def test_sparse_network():
    """Export a small spiking network: a 0/1 spike filter followed by a
    per-step sum of squares.

    NOTE: the traced IR is verified by counting torch.aten.mul.Tensor ops,
    so the tensor expressions below must keep their exact form. Only the
    PyTorch result is compared numerically; sparse_jit is disabled here.
    """

    def spike(input):
        # Step function: 1.0 where input >= 0, else 0.0.
        return (input >= 0).float()

    def sqSum(input):
        # Sum of squares over all elements (a 0-dim tensor).
        return (input * input).sum()

    class LIF(nn.Module):
        # Presumably a leaky integrate-and-fire neuron: `thresh` acts as the
        # firing threshold and `decay` scales the carried-over `mem` each step.
        def __init__(self):
            super(LIF, self).__init__()
            self.thresh = 1.0
            self.decay = 0.5
            self.act = spike

        def forward(self, X):
            """A filter that yields a binary-valued sparse tensor.

            Entries are 0.0 or 1.0; the stacked result itself is a dense tensor.
            """
            mem = 0
            spike_pot = []
            # One step per slice along the last axis of X.
            T = X.size(-1)
            for t in range(T):
                # Decay the carried value and add the current slice.
                mem = mem * self.decay + X[..., t]
                # 1.0 wherever mem - thresh >= 0.
                spike = self.act(mem - self.thresh)
                # Round trip through a sparse layout (values unchanged),
                # presumably so the export contains to_sparse/to_dense nodes.
                spike = spike.to_sparse().to_dense()  # prop hack
                # Reset mem to zero wherever a spike fired.
                mem = mem * (1.0 - spike)
                spike_pot.append(spike)
            # Re-assemble the per-step spikes along a new last axis.
            spike_pot = torch.stack(spike_pot, dim=-1)
            return spike_pot

    class tdLayer(nn.Module):
        # Applies `layer` to every slice along the last axis and stacks the
        # per-slice results along a new last axis.
        def __init__(self, layer):
            super(tdLayer, self).__init__()
            self.layer = layer

        def forward(self, X):
            T = X.size(-1)
            out = []
            for t in range(T):
                m = self.layer(X[..., t])
                out.append(m)
            out = torch.stack(out, dim=-1)
            return out

    class Block(nn.Module):
        # Spike filter followed by a per-step sum of squares; for a
        # [2,3,8,8] input the output has shape [8] (see the IR signature).
        def __init__(self):
            super(Block, self).__init__()
            self.spike = LIF()
            self.layer = tdLayer(sqSum)

        def forward(self, X):
            out = self.spike(X)
            out = self.layer(out)
            return out

    net = Block()

    # Get a random (but reproducible) input, so that a
    # general sparse tensor appears after LIF.
    torch.manual_seed(0)
    x = torch.rand(2, 3, 8, 8)
    m = export_and_import(net, x)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    # res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    # print("torch.mlir")
    # print(res2)
|
2024-05-15 03:13:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
# CHECK: torch.mlir
#
def test_sparse_feature_scaling():
    """Scale each row of a random 4x4 feature matrix by the reciprocal of its
    row sum, using a diagonal scaling matrix converted with ``to_sparse()``.
    A zero row sum (infinite reciprocal) results in a scale of 0."""

    class Scale(nn.Module):
        def forward(self, F):
            reciprocal_vector = 1 / torch.sum(F, dim=1)
            # Rows summing to zero would give inf; scale those by zero instead.
            reciprocal_vector[reciprocal_vector == float("inf")] = 0
            return torch.diag(reciprocal_vector).to_sparse() @ F

    net = Scale()

    # Get a random (but reproducible) features input.
    torch.manual_seed(0)
    features = torch.rand(4, 4)
    # The imported module is built but intentionally not printed here.
    export_and_import(net, features)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(features)
    # TODO: make this work
    # res2 = sparse_jit(net, features)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
|
2024-05-18 06:43:50 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
#
# CHECK-LABEL: test_sparse_gcn
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>,
# CHECK-SAME: %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> {
# CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_4_torch.float32> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
# CHECK: %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK: %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#[[$COO]]>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
# CHECK: %[[ONE:.*]] = torch.constant.int 1
# CHECK: %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32>
# CHECK: return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
# CHECK: [5.7502, 5.7502, 5.7502, 5.7502],
# CHECK: [4.6980, 4.6980, 4.6980, 4.6980],
# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
#
def test_sparse_gcn():
    """One graph-convolution layer with a sparse (COO) adjacency matrix.

    Computes adj_mat @ (inp @ kernel) + bias with all-ones kernel and bias,
    so all entries within a result row are equal.
    """

    class GraphConv(nn.Module):
        def __init__(self, input_dim, output_dim):
            super().__init__()
            # Allocate uninitialized storage (instead of the legacy
            # torch.Tensor(...) constructor), then fill it with ones.
            self.kernel = nn.Parameter(torch.empty(input_dim, output_dim))
            nn.init.ones_(self.kernel)
            self.bias = nn.Parameter(torch.empty(output_dim))
            nn.init.ones_(self.bias)

        def forward(self, inp, adj_mat):
            # Input matrix times weight matrix.
            support = torch.mm(inp, self.kernel)
            # Sparse adjacency matrix times support matrix.
            output = torch.spmm(adj_mat, support)
            # Add bias.
            output = output + self.bias
            return output

    net = GraphConv(4, 4)

    # Get random (but reproducible) matrices.
    torch.manual_seed(0)
    inp = torch.rand(4, 4)
    adj_mat = torch.rand(4, 4).to_sparse()
    m = export_and_import(net, inp, adj_mat)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    # Set to inference mode to avoid autograd component in result.
    with torch.no_grad():
        res1 = net(inp, adj_mat)
        res2 = sparse_jit(net, inp, adj_mat)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)
|