[torch-mlir][sparse] replace xavier with ones initialization (#3374)

ensures stability of results between different set ups
pull/3376/head
Aart Bik 2024-05-21 17:12:55 -07:00 committed by GitHub
parent fcf48872b3
commit 560ca24771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 9 deletions

View File

@ -703,22 +703,22 @@ def test_sparse_feature_scaling():
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956],
# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580],
# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676],
# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}})
# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
# CHECK: [5.7502, 5.7502, 5.7502, 5.7502],
# CHECK: [4.6980, 4.6980, 4.6980, 4.6980],
# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
# CHECK: torch.mlir
# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ]
# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ]
# CHECK: [ {{1.7397[0-9]* -0.12080[0-9]* 1.4058[0-9]* 2.1676[0-9]*}} ]
# CHECK: [ {{1.8583[0-9]* 0.71777[0-9]* 1.3857[0-9]* 1.4672[0-9]*}} ]{{\]}}
# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
#
def test_sparse_gcn():
class GraphConv(nn.Module):
def __init__(self, input_dim, output_dim):
super(GraphConv, self).__init__()
self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
nn.init.xavier_normal_(self.kernel)
nn.init.ones_(self.kernel)
self.bias = nn.Parameter(torch.Tensor(output_dim))
nn.init.ones_(self.bias)