mirror of https://github.com/llvm/torch-mlir
[sparse] match fx node using target name instead of variables name (#3315)
parent
64b59c7fc3
commit
2c22087cab
|
@ -120,16 +120,18 @@ def sparse_export(
|
|||
node.meta["sparsity"] = sparse_metadata(args[k])
|
||||
k = k + 1
|
||||
elif node.op == "call_function":
|
||||
# TODO: use upstream _opname implementation when available
|
||||
opname = node.target._schema.name.split("::")[1]
|
||||
# Zero preserving elt-wise unary op.
|
||||
if node.name in {"abs", "neg", "relu", "sin"}:
|
||||
if opname in {"abs", "neg", "relu", "sin"}:
|
||||
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
elif node.name == "_to_sparse":
|
||||
elif opname == "_to_sparse":
|
||||
dim = len(node.meta.get("val").shape)
|
||||
node.meta["sparsity"] = SparsityMeta(
|
||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||
)
|
||||
# TODO: Uncomment this to hack sparsity into the network.
|
||||
# elif node.name == "_to_dense":
|
||||
# elif opname == "_to_dense":
|
||||
# # hack (assumes we never really want the to_dense for now)
|
||||
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||
return prog
|
||||
|
|
Loading…
Reference in New Issue