diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index e4e95a9a8..bfe404c92 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -630,3 +630,47 @@ def test_sparse_network():
     print(res1)
     print("torch.mlir")
     print(res2)
+
+
+@run
+#
+# CHECK-LABEL: test_sparse_feature_scaling
+# CHECK: func.func @main(
+# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
+# ... more IR ...
+# CHECK: %[[D:.*]] = torch.operator "torch.aten._to_sparse"
+# CHECK: %[[R:.*]] = torch.aten.mm %[[D]], %[[A]]
+# CHECK: return %[[R]] : !torch.vtensor<[4,4],f32>
+# CHECK: }
+#
+# CHECK: torch.sparse
+# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
+# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
+# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
+# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
+# CHECK: torch.mlir
+#
+def test_sparse_feature_scaling():
+    class Scale(nn.Module):
+        def forward(self, F):
+            sum_vector = torch.sum(F, dim=1)
+            reciprocal_vector = 1 / sum_vector
+            reciprocal_vector[reciprocal_vector == float("inf")] = 0
+            scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
+            return scaling_diagonal @ F
+
+    net = Scale()
+
+    # Get a random (but reproducible) features input.
+    torch.manual_seed(0)
+    f = torch.rand(4, 4)
+    m = export_and_import(net, f)
+    print(m)
+
+    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
+    res1 = net(f)
+    # TODO: make this work
+    # res2 = sparse_jit(net, f)
+    print("torch.sparse")
+    print(res1)
+    print("torch.mlir")
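
Side note (not part of the patch): the new Scale module is row normalization expressed as a sparse diagonal matmul. A minimal standalone sketch, assuming only stock PyTorch and reusing the test's seeded input, that checks this equivalence:

import torch

# Reproduce the computation from Scale.forward on the same seeded input.
torch.manual_seed(0)
f = torch.rand(4, 4)

sum_vector = torch.sum(f, dim=1)
reciprocal_vector = 1 / sum_vector
# Guard against all-zero rows, whose reciprocal row sums would be inf.
reciprocal_vector[reciprocal_vector == float("inf")] = 0
scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
res = scaling_diagonal @ f

# The sparse-diagonal product is plain row normalization: each row sums to 1.
assert torch.allclose(res, f / f.sum(dim=1, keepdim=True))
assert torch.allclose(res.sum(dim=1), torch.ones(4))

This matches the CHECK lines above: the exported IR builds the diagonal densely and converts it via torch.aten._to_sparse just before the torch.aten.mm with the input.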