[sparse] match fx node using target name instead of variables name (#3315)

pull/3321/head
Peiming Liu 2024-05-09 12:34:14 -07:00 committed by GitHub
parent 64b59c7fc3
commit 2c22087cab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 3 deletions

View File

@ -120,16 +120,18 @@ def sparse_export(
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
# TODO: use upstream _opname implementation when available
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if node.name in {"abs", "neg", "relu", "sin"}:
if opname in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif node.name == "_to_sparse":
elif opname == "_to_sparse":
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif node.name == "_to_dense":
# elif opname == "_to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
return prog