# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

from typing import Any, Callable, Optional, Tuple, Dict

import torch
import torch.nn as nn
import numpy as np

from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
    RefBackendLinalgOnTensorsBackend,
)


def export_and_import(f, *args, **kwargs):
    """A FX graph importer, stripped down to essentials.

    Exports `f` with `torch.export.export` using the given positional and
    keyword arguments, applies the decomposition table (when it is
    non-empty), imports the resulting program through a fresh `FxImporter`
    and returns the importer's module.
    """
    context = ir.Context()
    torch_d.register_dialect(context)
    fx_importer = FxImporter(context=context)
    prog = torch.export.export(f, args, kwargs)
    decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)
    fx_importer.import_frozen_program(prog)
    return fx_importer.module


def sparse_jit(f, *args, **kwargs):
    """This method compiles and runs the given callable using linalg backend.

    `f` must provide `named_buffers` (i.e. behave like a `torch.nn.Module`).
    Every positional argument must be a tensor (its `layout` is inspected).
    Note that `**kwargs` are forwarded to the export step only; they are NOT
    passed on to the compiled function when it is invoked.

    Returns whatever the compiled `main` entry point returns (numpy arrays,
    per the note on argument preparation below).
    """
    # Import module and lower into Linalg IR.
    module = export_and_import(f, *args, **kwargs)
    run_pipeline_with_repro_report(
        module,
        (
            "builtin.module("
            "func.func(torch-decompose-complex-ops),"
            "torch-backend-to-linalg-on-tensors-backend-pipeline)"
        ),
        "Lowering TorchFX IR -> Linalg IR",
        enable_ir_printing=False,
    )
    # Compile with reference Linalg backend.
    # TODO: runtime verification fails with 'rank mismatch' on memref.cast
    backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
    compiled = backend.compile(module)
    invoker = backend.load(compiled)
    xargs = []
    # Prepare all the named buffer parameters (assume all dense).
    # All scalar arguments are filtered out since they appear inline.
    params = dict(f.named_buffers(remove_duplicate=True))
    params_flat, _ = torch.utils._pytree.tree_flatten(params)
    for p in params_flat:
        if len(p.shape) > 0:
            xargs.append(p.numpy())
    # Prepare input parameters. Sparse input tensors are split into
    # their composite tensors. All PyTorch tensors are converted
    # to their backing numpy arrays. Note that the output consists
    # of numpy arrays as well, which can trivially be reconstructed
    # into PyTorch tensors (dense and sparse).
    for a in args:
        if a.layout is torch.sparse_coo:
            # Construct the additional position array required by MLIR with data
            # array([0, nnz]). The COO format always uses int64 indices.
            xargs.append(np.array([0, a._nnz()], dtype=np.int64))
            # Transform a tensor<ndim x nnz> into ndim x tensor<nnz> to conform
            # to the MLIR SoA COO representation.
            xargs.extend(idx.numpy() for idx in a._indices())
            xargs.append(a._values().numpy())
        elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
            xargs.append(a.crow_indices().numpy())
            xargs.append(a.col_indices().numpy())
            xargs.append(a.values().numpy())
        elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
            xargs.append(a.ccol_indices().numpy())
            xargs.append(a.row_indices().numpy())
            xargs.append(a.values().numpy())
        else:
            xargs.append(a.numpy())
    # Invoke.
    return invoker.main(*xargs)


def run(f):
    """Print a header for test `f` and execute it immediately.

    Used as a decorator: the test body runs at definition time so that its
    output can be matched by FileCheck. Nothing is returned, hence the
    decorated name is rebound to `None`.
    """
    # Prompt test name and torch version (for debugging).
    print(f"{f.__name__} ({torch.__version__})")
    print("-" * len(f.__name__))
    f()
    print()


@run
#
# CHECK-LABEL: test_sparse_id
# CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> {
# CHECK:         return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK:                               [ 0, 1, 10, 19]{{\]}}),
# CHECK:   values=tensor([-1000., -1., 1., 1000.]),
# CHECK:   size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK:   [0 4]
# CHECK:   [0 1 2 9]
# CHECK:   [ 0 1 10 19]
# CHECK:   [-1000. -1. 1. 1000.]
#
def test_sparse_id():
    """Identity on a 2-dim COO tensor: print the IR and both results."""

    class IdNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    net = IdNet()
    idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
    val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
    sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20])
    m = export_and_import(net, sparse_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2[0])
    print(res2[1])
    print(res2[2])
    print(res2[3])


@run
#
# CHECK-LABEL: test_sparse_sum
# CHECK:       #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK:         %[[N:.*]] = torch.constant.none
# CHECK:         %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK:         return %[[R]] : !torch.vtensor<[],f32>
# CHECK:       }
#
# CHECK: torch.sparse = tensor(4096.)
# CHECK: torch.mlir = 4096.0
#
def test_sparse_sum():
    """Full reduction (`sum`) over a CSR matrix of ones."""

    class SumNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.sum()

    net = SumNet()
    dense_input = torch.ones(64, 64)
    sparse_input = dense_input.to_sparse_csr()
    m = export_and_import(net, sparse_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse =", res1)
    print("torch.mlir =", res2)


@run
#
# CHECK-LABEL: test_sparse_SpMV
# CHECK:       #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>,
# CHECK-SAME:    %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> {
# CHECK:         %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
# CHECK:         return %[[R]] : !torch.vtensor<[10],f32>
# CHECK:       }
#
# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.])
# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.]
#
def test_sparse_SpMV():
    """Matrix-vector product with a 2x2-blocked BSR matrix."""

    class SpMVNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, v):
            return torch.mv(x, v)

    net = SpMVNet()
    dense_vector = torch.arange(1, 11, dtype=torch.float32)
    dense_input = torch.ones(10, 10)
    sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2))
    m = export_and_import(net, sparse_input, dense_vector)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input, dense_vector)
    res2 = sparse_jit(net, sparse_input, dense_vector)
    print("torch.sparse =", res1)
    print("torch.mlir =", res2)


@run
#
# CHECK-LABEL: test_sparse_SpMM
# CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>,
# CHECK-SAME:    %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> {
# CHECK:         %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32>
# CHECK:         return %[[R]] : !torch.vtensor<[8,8],f32>
# CHECK:       }
#
# CHECK:        torch.sparse
# CHECK:        tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK-COUNT-6:             [8., 8., 8., 8., 8., 8., 8., 8.],
# CHECK:                     [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}})
# CHECK:        torch.mlir
# CHECK:        {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK-COUNT-6:      [8. 8. 8. 8. 8. 8. 8. 8.]
# CHECK:              [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}}
#
def test_sparse_SpMM():
    """Matrix-matrix product of a COO matrix with a dense matrix."""

    class MatMulNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.matmul(x, y)

    net = MatMulNet()
    dense_input = torch.ones(8, 8)
    sparse_input = dense_input.to_sparse_coo()
    m = export_and_import(net, sparse_input, dense_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input, dense_input)
    res2 = sparse_jit(net, sparse_input, dense_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)


@run
#
# CHECK-LABEL: test_sparse_eltwise
# CHECK:       #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
# CHECK:         %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK:         return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK:       }
# CHECK:       #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
# CHECK:         %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK:         return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
# CHECK:   col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
# CHECK:   values=tensor({{\[}}[ -1., -2.],
# CHECK:                       [ -3., -4.],
# CHECK:                       [ -5., -6.],
# CHECK:                       [ -7., -8.],
# CHECK:                       [ -9., -10.],
# CHECK:                       [-11., -12.],
# CHECK:                       [-13., -14.],
# CHECK:                       [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
# CHECK:   layout=torch.sparse_csr)
# CHECK: torch.mlir
# CHECK:   [0 2 4 6 8]
# CHECK:   [0 1 0 1 0 1 0 1]
# CHECK:   [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
# CHECK:     -15. -16.]
# CHECK: torch.mlir.batch
#
def test_sparse_eltwise():
    """Element-wise negation on CSR with a dense sub-tensor and on batched CSR.

    Both variants are imported and printed; only the non-batched variant is
    executed through `sparse_jit` (see the TODO below).
    """

    class EltNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return -x

    net = EltNet()
    dense_input = torch.reshape(
        torch.arange(1, 17, dtype=torch.float32), shape=(4, 2, 2)
    )

    # This yields a plain CSR with dense **sub**tensor
    sparse_input = dense_input.to_sparse_csr(dense_dim=1)
    m = export_and_import(net, sparse_input)
    print(m)

    # This yields a **batched** CSR.
    batch_input = dense_input.to_sparse_csr(dense_dim=0)
    m = export_and_import(net, batch_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    # TODO: make this work in MLIR
    # res3 = sparse_jit(net, batch_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2[0])
    print(res2[1])
    print(res2[2])
    print("torch.mlir.batch")


@run
#
# CHECK-LABEL: test_sparse_coo3
# CHECK:       #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#[[$COO3]]>) -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> {
# CHECK:         %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK:         return %[[R]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor(indices=tensor({{\[}}[ 0, 1, 1, 4, 9, 9],
# CHECK:                               [ 0, 1, 1, 5, 19, 19],
# CHECK:                               [ 0, 1, 3, 6, 28, 29]{{\]}}),
# CHECK:   values=tensor([ 0., 0., 1., 2., 3., 1000.]),
# CHECK:   size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK:   [0 6]
# CHECK:   [0 1 1 4 9 9]
# CHECK:   [ 0 1 1 5 19 19]
# CHECK:   [ 0 1 3 6 28 29]
# CHECK:   [ 0. 0. 1. 2. 3. 1000.]
#
def test_sparse_coo3():
    """ReLU applied to a directly constructed 3-dim COO tensor."""

    class COO3Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    net = COO3Net()

    # Direct 3-dim COO construction.
    idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]])
    val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64)
    sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30])

    m = export_and_import(net, sparse_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2[0])
    print(res2[1])
    print(res2[2])
    print(res2[3])
    print(res2[4])


@run
#
# CHECK-LABEL: test_sparse_activation
# CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> {
# CHECK:         %[[N1:.*]] = torch.constant.none
# CHECK:         %[[N2:.*]] = torch.constant.none
# CHECK:         %[[N3:.*]] = torch.constant.none
# CHECK:         %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK:         return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1],
# CHECK:                               [0, 0, 1, 1, 0, 0, 1, 1],
# CHECK:                               [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}),
# CHECK:   values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
# CHECK:   size=(2, 2, 2), nnz=8, layout=torch.sparse_coo)
# CHECK: torch.mlir
# CHECK:   [0 8]
# CHECK:   [0 0 0 0 1 1 1 1]
# CHECK:   [0 0 1 1 0 0 1 1]
# CHECK:   [0 1 0 1 0 1 0 1]
# CHECK:   [1. 1. 1. 1. 1. 1. 1. 1.]
#
def test_sparse_activation():
    """Dense-to-sparse conversion (`to_sparse`) inside the network."""

    class SparseActivationCOO(torch.nn.Module):
        def forward(self, x):
            return x.to_sparse()

    net = SparseActivationCOO()
    x = torch.ones(2, 2, 2)
    m = export_and_import(net, x)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2[0])
    print(res2[1])
    print(res2[2])
    print(res2[3])
    print(res2[4])


@run
#
# CHECK-LABEL: test_sparse_network
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
#                ... lots of IR ...
# CHECK-COUNT-15: torch.aten.mul.Tensor
#                ... lots of IR ...
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
# CHECK: torch.mlir
# CHECK:   [ 0. 11. 9. 11. 13. 11. 10. 12.]
#
def test_sparse_network():
    """Small network whose intermediate activations pass through `to_sparse`."""

    def spike(input):
        return (input >= 0).float()

    def sqSum(input):
        return (input * input).sum()

    class LIF(nn.Module):
        def __init__(self):
            super().__init__()
            self.thresh = 1.0
            self.decay = 0.5
            self.act = spike

        def forward(self, X):
            """A filter that yields a binary-valued sparse tensor."""
            mem = 0
            spike_pot = []
            T = X.size(-1)
            for t in range(T):
                mem = mem * self.decay + X[..., t]
                spike = self.act(mem - self.thresh)
                spike = spike.to_sparse().to_dense()  # prop hack
                mem = mem * (1.0 - spike)
                spike_pot.append(spike)
            spike_pot = torch.stack(spike_pot, dim=-1)
            return spike_pot

    class tdLayer(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.layer = layer

        def forward(self, X):
            T = X.size(-1)
            out = []
            for t in range(T):
                m = self.layer(X[..., t])
                out.append(m)
            out = torch.stack(out, dim=-1)
            return out

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.spike = LIF()
            self.layer = tdLayer(sqSum)

        def forward(self, X):
            out = self.spike(X)
            out = self.layer(out)
            return out

    net = Block()

    # Get a random (but reproducible) input, so that a
    # general sparse tensor appears after LIF.
    torch.manual_seed(0)
    x = torch.rand(2, 3, 8, 8)
    m = export_and_import(net, x)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)


@run
#
# CHECK-LABEL: test_sparse_feature_scaling
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
#                ... more IR ...
# CHECK:         %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"
# CHECK:         %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]]
# CHECK:         return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
# CHECK:                [0.1321, 0.2724, 0.2105, 0.3851],
# CHECK:                [0.2478, 0.3439, 0.1898, 0.2185],
# CHECK:                [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
#
# TODO: first row looks suspect...
#
# CHECK: torch.mlir
# CHECK:   {{\[}}[0. 0. 0. 0. ]
# CHECK:         [0.13205223 0.27236593 0.21051763 0.38506418]
# CHECK:         [0.24781987 0.34391665 0.18976606 0.2184974 ]
# CHECK:         [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}}
#
def test_sparse_feature_scaling():
    """Row scaling of a feature matrix via a sparse diagonal matrix."""

    class Scale(nn.Module):
        def forward(self, F):
            sum_vector = torch.sum(F, dim=1)
            reciprocal_vector = 1 / sum_vector
            # Rows that sum to zero would yield inf; map those to zero.
            reciprocal_vector[reciprocal_vector == float("inf")] = 0
            scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
            return scaling_diagonal @ F

    net = Scale()

    # Get a random (but reproducible) features input.
    torch.manual_seed(0)
    f = torch.rand(4, 4)
    m = export_and_import(net, f)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(f)
    res2 = sparse_jit(net, f)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)


@run
#
# CHECK-LABEL: test_sparse_gcn
# CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK:       func.func @main(
# CHECK-SAME:    %[[A:.*]]: !torch.vtensor<[4,4],f32>,
# CHECK-SAME:    %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> {
# CHECK:         %[[LIT:.*]] = torch.vtensor.literal(dense_resource<{{.*}}> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
# CHECK:         %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK:         %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#[[$COO]]>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK:         %[[BIAS:.*]] = torch.vtensor.literal(dense_resource<{{.*}}> : tensor<4xf32>) : !torch.vtensor<[4],f32>
# CHECK:         %[[ONE:.*]] = torch.constant.int 1
# CHECK:         %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32>
# CHECK:         return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK:       }
#
# CHECK: torch.sparse
# CHECK:   tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
# CHECK:                [5.7502, 5.7502, 5.7502, 5.7502],
# CHECK:                [4.6980, 4.6980, 4.6980, 4.6980],
# CHECK:                [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
# CHECK: torch.mlir
# CHECK:   {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
# CHECK:         [5.7501717 5.7501717 5.7501717 5.7501717]
# CHECK:         [4.697952 4.697952 4.697952 4.697952 ]
# CHECK:         [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
#
def test_sparse_gcn():
    """Graph-convolution layer: dense mm, sparse adjacency mm, plus bias."""

    class GraphConv(nn.Module):
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
            nn.init.ones_(self.kernel)
            self.bias = nn.Parameter(torch.Tensor(output_dim))
            nn.init.ones_(self.bias)

        def forward(self, inp, adj_mat):
            # Input matrix times weight matrix.
            support = torch.mm(inp, self.kernel)
            # Sparse adjacency matrix times support matrix.
            output = torch.spmm(adj_mat, support)
            # Add bias.
            output = output + self.bias
            return output

    net = GraphConv(4, 4)

    # Get random (but reproducible) matrices.
    torch.manual_seed(0)
    inp = torch.rand(4, 4)
    adj_mat = torch.rand(4, 4).to_sparse()
    m = export_and_import(net, inp, adj_mat)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    # Set to inference mode to avoid autograd component in result.
    with torch.no_grad():
        res1 = net(inp, adj_mat)
        res2 = sparse_jit(net, inp, adj_mat)
        print("torch.sparse")
        print(res1)
        print("torch.mlir")
        print(res2)