mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] replace xavier with ones initialization (#3374)
ensures stability of results between different set upspull/3376/head
parent
fcf48872b3
commit
560ca24771
|
@ -703,22 +703,22 @@ def test_sparse_feature_scaling():
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
#
|
#
|
||||||
# CHECK: torch.sparse
|
# CHECK: torch.sparse
|
||||||
# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956],
|
# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
|
||||||
# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580],
|
# CHECK: [5.7502, 5.7502, 5.7502, 5.7502],
|
||||||
# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676],
|
# CHECK: [4.6980, 4.6980, 4.6980, 4.6980],
|
||||||
# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}})
|
# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
|
||||||
# CHECK: torch.mlir
|
# CHECK: torch.mlir
|
||||||
# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ]
|
# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
|
||||||
# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ]
|
# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
|
||||||
# CHECK: [ {{1.7397[0-9]* -0.12080[0-9]* 1.4058[0-9]* 2.1676[0-9]*}} ]
|
# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
|
||||||
# CHECK: [ {{1.8583[0-9]* 0.71777[0-9]* 1.3857[0-9]* 1.4672[0-9]*}} ]{{\]}}
|
# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
|
||||||
#
|
#
|
||||||
def test_sparse_gcn():
|
def test_sparse_gcn():
|
||||||
class GraphConv(nn.Module):
|
class GraphConv(nn.Module):
|
||||||
def __init__(self, input_dim, output_dim):
|
def __init__(self, input_dim, output_dim):
|
||||||
super(GraphConv, self).__init__()
|
super(GraphConv, self).__init__()
|
||||||
self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
|
self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
|
||||||
nn.init.xavier_normal_(self.kernel)
|
nn.init.ones_(self.kernel)
|
||||||
self.bias = nn.Parameter(torch.Tensor(output_dim))
|
self.bias = nn.Parameter(torch.Tensor(output_dim))
|
||||||
nn.init.ones_(self.bias)
|
nn.init.ones_(self.bias)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue