[torch-mlir][sparse] sparse diagonal feature scaling test (#3344)

pull/3346/head
Aart Bik 2024-05-14 12:13:54 -07:00 committed by GitHub
parent 8e74d64e8f
commit 44fa6c3afd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 0 deletions

View File

@ -630,3 +630,47 @@ def test_sparse_network():
print(res1)
print("torch.mlir")
print(res2)
@run
#
# CHECK-LABEL: test_sparse_feature_scaling
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
# ... more IR ...
# CHECK: %[[D:.*]] = torch.operator "torch.aten._to_sparse"
# CHECK: %[[R:.*]] = torch.aten.mm %[[D]], %[[A]]
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
# CHECK: torch.mlir
#
def test_sparse_feature_scaling():
class Scale(nn.Module):
def forward(self, F):
sum_vector = torch.sum(F, dim=1)
reciprocal_vector = 1 / sum_vector
reciprocal_vector[reciprocal_vector == float("inf")] = 0
scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
return scaling_diagonal @ F
net = Scale()
# Get a random (but reproducible) features input.
torch.manual_seed(0)
f = torch.rand(4, 4)
m = export_and_import(net, f)
print(m)
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(f)
# TODO: make this work
# res2 = sparse_jit(net, f)
print("torch.sparse")
print(res1)
print("torch.mlir")