mirror of https://github.com/llvm/torch-mlir
[sparse] match fx node using target name instead of variables name (#3315)
parent
64b59c7fc3
commit
2c22087cab
|
@ -120,16 +120,18 @@ def sparse_export(
|
||||||
node.meta["sparsity"] = sparse_metadata(args[k])
|
node.meta["sparsity"] = sparse_metadata(args[k])
|
||||||
k = k + 1
|
k = k + 1
|
||||||
elif node.op == "call_function":
|
elif node.op == "call_function":
|
||||||
|
# TODO: use upstream _opname implementation when available
|
||||||
|
opname = node.target._schema.name.split("::")[1]
|
||||||
# Zero preserving elt-wise unary op.
|
# Zero preserving elt-wise unary op.
|
||||||
if node.name in {"abs", "neg", "relu", "sin"}:
|
if opname in {"abs", "neg", "relu", "sin"}:
|
||||||
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||||
elif node.name == "_to_sparse":
|
elif opname == "_to_sparse":
|
||||||
dim = len(node.meta.get("val").shape)
|
dim = len(node.meta.get("val").shape)
|
||||||
node.meta["sparsity"] = SparsityMeta(
|
node.meta["sparsity"] = SparsityMeta(
|
||||||
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
|
||||||
)
|
)
|
||||||
# TODO: Uncomment this to hack sparsity into the network.
|
# TODO: Uncomment this to hack sparsity into the network.
|
||||||
# elif node.name == "_to_dense":
|
# elif opname == "_to_dense":
|
||||||
# # hack (assumes we never really want the to_dense for now)
|
# # hack (assumes we never really want the to_dense for now)
|
||||||
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
|
||||||
return prog
|
return prog
|
||||||
|
|
Loading…
Reference in New Issue