From e80f072ba43c89af300190a2f9b3d63f9e36c84d Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 17 May 2024 15:43:50 -0700 Subject: [PATCH] [torch-mlir][sparse] example of a sparse graph convolution (#3363) --- test/python/fx_importer/sparse_test.py | 63 ++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 87d2e3d96..30f1f21b4 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -684,3 +684,66 @@ def test_sparse_feature_scaling(): print("torch.sparse") print(res1) print("torch.mlir") + + +@run +# +# CHECK-LABEL: test_sparse_gcn +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>, +# CHECK-SAME: %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> { +# CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense_resource : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32> +# CHECK: %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32> +# CHECK: %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#sparse>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32> +# CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense_resource : tensor<4xf32>) : !torch.vtensor<[4],f32> +# CHECK: %[[ONE:.*]] = torch.constant.int 1 +# CHECK: %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32> +# CHECK return %[[R]] : !torch.vtensor<[4,4],f32> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956], +# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580], +# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676], +# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}}, grad_fn=<{{.*}}>) +# CHECK: torch.mlir +# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ] +# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ] +# CHECK: [ {{1.7397[0-9]* -0.12080[0-9]* 1.4058[0-9]* 2.1676[0-9]*}} ] +# CHECK: [ {{1.8583[0-9]* 0.71777[0-9]* 1.3857[0-9]* 1.4672[0-9]*}} ]{{\]}} +# +def test_sparse_gcn(): + class GraphConv(nn.Module): + def __init__(self, input_dim, output_dim): + super(GraphConv, self).__init__() + self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim)) + nn.init.xavier_normal_(self.kernel) + self.bias = nn.Parameter(torch.Tensor(output_dim)) + nn.init.ones_(self.bias) + + def forward(self, inp, adj_mat): + # Input matrix times weight matrix. + support = torch.mm(inp, self.kernel) + # Sparse adjacency matrix times support matrix. + output = torch.spmm(adj_mat, support) + # Add bias. + output = output + self.bias + return output + + net = GraphConv(4, 4) + + # Get a random (but reproducible) matrices. + torch.manual_seed(0) + inp = torch.rand(4, 4) + adj_mat = torch.rand(4, 4).to_sparse() + m = export_and_import(net, inp, adj_mat) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(inp, adj_mat) + res2 = sparse_jit(net, inp, adj_mat) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2)