mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] example of a sparse graph convolution (#3363)
parent
6cba93b16e
commit
e80f072ba4
|
@ -684,3 +684,66 @@ def test_sparse_feature_scaling():
|
|||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_gcn
|
||||
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>,
|
||||
# CHECK-SAME: %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> {
|
||||
# CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_4_torch.float32> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
|
||||
# CHECK: %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
|
||||
# CHECK: %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#sparse>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
|
||||
# CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
# CHECK: %[[ONE:.*]] = torch.constant.int 1
|
||||
# CHECK: %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956],
|
||||
# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580],
|
||||
# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676],
|
||||
# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}}, grad_fn=<{{.*}}>)
|
||||
# CHECK: torch.mlir
|
||||
# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ]
|
||||
# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ]
|
||||
# CHECK: [ {{1.7397[0-9]* -0.12080[0-9]* 1.4058[0-9]* 2.1676[0-9]*}} ]
|
||||
# CHECK: [ {{1.8583[0-9]* 0.71777[0-9]* 1.3857[0-9]* 1.4672[0-9]*}} ]{{\]}}
|
||||
#
|
||||
def test_sparse_gcn():
|
||||
class GraphConv(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(GraphConv, self).__init__()
|
||||
self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
|
||||
nn.init.xavier_normal_(self.kernel)
|
||||
self.bias = nn.Parameter(torch.Tensor(output_dim))
|
||||
nn.init.ones_(self.bias)
|
||||
|
||||
def forward(self, inp, adj_mat):
|
||||
# Input matrix times weight matrix.
|
||||
support = torch.mm(inp, self.kernel)
|
||||
# Sparse adjacency matrix times support matrix.
|
||||
output = torch.spmm(adj_mat, support)
|
||||
# Add bias.
|
||||
output = output + self.bias
|
||||
return output
|
||||
|
||||
net = GraphConv(4, 4)
|
||||
|
||||
# Get a random (but reproducible) matrices.
|
||||
torch.manual_seed(0)
|
||||
inp = torch.rand(4, 4)
|
||||
adj_mat = torch.rand(4, 4).to_sparse()
|
||||
m = export_and_import(net, inp, adj_mat)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(inp, adj_mat)
|
||||
res2 = sparse_jit(net, inp, adj_mat)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
|
Loading…
Reference in New Issue