[torch-mlir][sparse] inference mode for sparse GCN test (#3369)

pull/3133/head
Aart Bik 2024-05-20 19:52:16 -07:00 committed by GitHub
parent 297c270980
commit c0e7d2667d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 7 deletions

View File

@ -706,7 +706,7 @@ def test_sparse_feature_scaling():
# CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956], # CHECK: tensor({{\[}}[ 1.8340, 0.1386, 1.4181, 1.9956],
# CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580], # CHECK: [ 2.2926, 0.0797, 1.6182, 2.1580],
# CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676], # CHECK: [ 1.7397, -0.1208, 1.4059, 2.1676],
# CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}}, grad_fn=<{{.*}}>) # CHECK: [ 1.8583, 0.7178, 1.3857, 1.4673]{{\]}})
# CHECK: torch.mlir # CHECK: torch.mlir
# CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ] # CHECK: {{\[}}[ {{1.8339[0-9]* 0.13862[0-9]* 1.4181[0-9]* 1.9955[0-9]*}} ]
# CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ] # CHECK: [ {{2.2926[0-9]* 0.07968[0-9]* 1.6181[0-9]* 2.1579[0-9]*}} ]
@ -741,6 +741,8 @@ def test_sparse_gcn():
print(m) print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
# Set to inference mode to avoid autograd component in result.
with torch.no_grad():
res1 = net(inp, adj_mat) res1 = net(inp, adj_mat)
res2 = sparse_jit(net, inp, adj_mat) res2 = sparse_jit(net, inp, adj_mat)
print("torch.sparse") print("torch.sparse")