[torch-mlir][sparse] example of a sparse graph convolution (#3363)

pull/3268/merge
Aart Bik 2024-05-17 15:43:50 -07:00 committed by GitHub
parent 6cba93b16e
commit e80f072ba4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 63 additions and 0 deletions

View File

@ -684,3 +684,66 @@ def test_sparse_feature_scaling():
print("torch.sparse")
print(res1)
print("torch.mlir")
@run
#
# CHECK-LABEL: test_sparse_gcn
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>,
# CHECK-SAME: %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> {
# CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_4_torch.float32> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
# CHECK: %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK: %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#sparse>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
# CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
# CHECK: %[[ONE:.*]] = torch.constant.int 1
# CHECK: %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32>
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956],
# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580],
# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676],
# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}}, grad_fn=<{{.*}}>)
# CHECK: torch.mlir
# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ]
# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ]
# CHECK: [ {{1.7397[0-9]* -0.12080[0-9]* 1.4058[0-9]* 2.1676[0-9]*}} ]
# CHECK: [ {{1.8583[0-9]* 0.71777[0-9]* 1.3857[0-9]* 1.4672[0-9]*}} ]{{\]}}
#
def test_sparse_gcn():
class GraphConv(nn.Module):
def __init__(self, input_dim, output_dim):
super(GraphConv, self).__init__()
self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
nn.init.xavier_normal_(self.kernel)
self.bias = nn.Parameter(torch.Tensor(output_dim))
nn.init.ones_(self.bias)
def forward(self, inp, adj_mat):
# Input matrix times weight matrix.
support = torch.mm(inp, self.kernel)
# Sparse adjacency matrix times support matrix.
output = torch.spmm(adj_mat, support)
# Add bias.
output = output + self.bias
return output
net = GraphConv(4, 4)
# Get a random (but reproducible) matrices.
torch.manual_seed(0)
inp = torch.rand(4, 4)
adj_mat = torch.rand(4, 4).to_sparse()
m = export_and_import(net, inp, adj_mat)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(inp, adj_mat)
res2 = sparse_jit(net, inp, adj_mat)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)