mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] sparse diagonal feature scaling test (#3344)
parent
8e74d64e8f
commit
44fa6c3afd
|
@ -630,3 +630,47 @@ def test_sparse_network():
|
|||
print(res1)
|
||||
print("torch.mlir")
|
||||
print(res2)
|
||||
|
||||
|
||||
@run
|
||||
#
|
||||
# CHECK-LABEL: test_sparse_feature_scaling
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
|
||||
# ... more IR ...
|
||||
# CHECK: %[[D:.*]] = torch.operator "torch.aten._to_sparse"
|
||||
# CHECK: %[[R:.*]] = torch.aten.mm %[[D]], %[[A]]
|
||||
# CHECK return %[[R]] : !torch.vtensor<[4,4],f32>
|
||||
# CHECK: }
|
||||
#
|
||||
# CHECK: torch.sparse
|
||||
# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
|
||||
# CHECK: [0.1321, 0.2724, 0.2105, 0.3851],
|
||||
# CHECK: [0.2478, 0.3439, 0.1898, 0.2185],
|
||||
# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
|
||||
# CHECK: torch.mlir
|
||||
#
|
||||
def test_sparse_feature_scaling():
|
||||
class Scale(nn.Module):
|
||||
def forward(self, F):
|
||||
sum_vector = torch.sum(F, dim=1)
|
||||
reciprocal_vector = 1 / sum_vector
|
||||
reciprocal_vector[reciprocal_vector == float("inf")] = 0
|
||||
scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
|
||||
return scaling_diagonal @ F
|
||||
|
||||
net = Scale()
|
||||
|
||||
# Get a random (but reproducible) features input.
|
||||
torch.manual_seed(0)
|
||||
f = torch.rand(4, 4)
|
||||
m = export_and_import(net, f)
|
||||
print(m)
|
||||
|
||||
# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
|
||||
res1 = net(f)
|
||||
# TODO: make this work
|
||||
# res2 = sparse_jit(net, f)
|
||||
print("torch.sparse")
|
||||
print(res1)
|
||||
print("torch.mlir")
|
||||
|
|
Loading…
Reference in New Issue