mirror of https://github.com/llvm/torch-mlir
Rewrite get_registered_ops.cpp to python version
parent
b08d08682f
commit
bda2214d49
|
@ -0,0 +1,60 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""Listing of the JIT operator registry, for use in generating the `torch` dialect.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
import pybind11
|
||||
|
||||
def get_registered_ops():
|
||||
results = []
|
||||
|
||||
# Walk the JIT operator registry to find all the ops that we might need
|
||||
# for introspection / ODS generation.
|
||||
# This registry contains a superset of the ops available to the dispatcher,
|
||||
# since the JIT has its own dispatch mechanism that it uses to implement
|
||||
# "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||
# as `aten::__is__`.
|
||||
for schema in torch._C._jit_get_all_schemas():
|
||||
record = {}
|
||||
|
||||
record["name"] = schema.name
|
||||
record["overload_name"] = schema.overload_name
|
||||
record["is_mutable"] = schema.is_mutable
|
||||
|
||||
arguments = []
|
||||
returns = []
|
||||
|
||||
def add_argument(container, arg):
|
||||
arg_record = {
|
||||
"name": arg.name,
|
||||
"type": arg.type.annotation_str,
|
||||
"kwarg_only" : arg.kwarg_only,
|
||||
"is_out": arg.is_out,
|
||||
}
|
||||
if arg.default_value:
|
||||
arg_record["default_value"] = arg.default_value
|
||||
if arg.alias_info:
|
||||
alias_info = {
|
||||
"is_write": arg.alias_info.is_write,
|
||||
"before_set": [str(symbol) for symbol in arg.alias_info.before_set],
|
||||
"after_set": [str(symbol) for symbol in arg.alias_info.after_set],
|
||||
}
|
||||
arg_record["alias_info"] = alias_info
|
||||
|
||||
container.append(arg_record)
|
||||
|
||||
for argument in schema.arguments:
|
||||
add_argument(arguments, argument)
|
||||
for return_arg in schema.returns:
|
||||
add_argument(returns, return_arg)
|
||||
|
||||
record["arguments"] = arguments
|
||||
record["returns"] = returns
|
||||
results.append(record)
|
||||
|
||||
return results
|
|
@ -14,9 +14,10 @@ import difflib
|
|||
from .utils import TextEmitter
|
||||
|
||||
# Note that this utility exists only in the c-extension.
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import (
|
||||
get_registered_ops,
|
||||
) # pytype: disable=import-error
|
||||
# from torch_mlir._mlir_libs._jit_ir_importer import (
|
||||
# get_registered_ops,
|
||||
# ) # pytype: disable=import-error
|
||||
from .get_registered_ops import get_registered_ops
|
||||
|
||||
|
||||
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
||||
|
@ -28,7 +29,7 @@ def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
|||
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||
default = ""
|
||||
if "default_debug" in arg:
|
||||
if "List" in arg["pytype"]:
|
||||
if "List" in arg["type"]:
|
||||
# TorchScript doesn't allow lists as default parameters due
|
||||
# to the weird Python semantics of mutable default
|
||||
# arguments. So munge them into tuples, which work
|
||||
|
@ -44,7 +45,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
|||
default_debug = "()"
|
||||
else:
|
||||
default_debug = default_list.replace("[", "(").replace("]", ",)")
|
||||
elif arg["pytype"] == "str":
|
||||
elif arg["type"] == "str":
|
||||
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
||||
else:
|
||||
default_debug = arg["default_debug"]
|
||||
|
@ -119,9 +120,9 @@ class JitOperator:
|
|||
self.namespace = namespace
|
||||
self.unqualified_name = unqualified_name
|
||||
self.overload_name = op_info["name"][1]
|
||||
self.is_c10_op = op_info["is_c10_op"]
|
||||
self.is_vararg = op_info["is_vararg"]
|
||||
self.is_varret = op_info["is_varret"]
|
||||
# self.is_c10_op = op_info["is_c10_op"]
|
||||
# self.is_vararg = op_info["is_vararg"]
|
||||
# self.is_varret = op_info["is_varret"]
|
||||
self.is_mutable = op_info["is_mutable"]
|
||||
self.arguments = op_info["arguments"]
|
||||
self.returns = op_info["returns"]
|
||||
|
@ -150,11 +151,14 @@ class JitOperator:
|
|||
concern for them)
|
||||
"""
|
||||
overload = "" if not self.overload_name else f".{self.overload_name}"
|
||||
if self.is_vararg:
|
||||
# Check if any argument is a variable argument
|
||||
if any(arg["type"] == "..." for arg in self.arguments):
|
||||
arg_str = "..."
|
||||
else:
|
||||
arg_str = ", ".join(arg["type"] for arg in self.arguments)
|
||||
if self.is_varret:
|
||||
|
||||
# Check if any return type is a variable return type
|
||||
if any(ret["type"] == "..." for ret in self.returns):
|
||||
ret_str = "..."
|
||||
else:
|
||||
ret_str = ", ".join(ret["type"] for ret in self.returns)
|
||||
|
@ -233,13 +237,13 @@ class JitOperator:
|
|||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||
pytype = _pytype_to_shape_fn_pytype(arg["type"])
|
||||
default = _get_default_value(arg)
|
||||
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||
return f"{parameter_name}: {pytype}{default}"
|
||||
|
||||
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
return _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||
return _pytype_to_shape_fn_pytype(arg["type"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"shape", parameter_decl_builder, ret_decl_builder
|
||||
|
@ -255,10 +259,10 @@ class JitOperator:
|
|||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||
pytype = _pytype_to_dtype_fn_pytype(arg["type"])
|
||||
default = _get_default_value(arg)
|
||||
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||
if "Tensor" in arg["pytype"]:
|
||||
if "Tensor" in arg["type"]:
|
||||
return f"{parameter_name}_rank_dtype: {pytype}{default}"
|
||||
return f"{parameter_name}: {pytype}{default}"
|
||||
|
||||
|
@ -267,9 +271,9 @@ class JitOperator:
|
|||
# results of type `number`. Here we handle this case because
|
||||
# `_pytype_to_dtype_fn_pytype` will replace `number` with
|
||||
# `Union[int, float]`.
|
||||
if arg["pytype"] in ["number", "Tensor"]:
|
||||
if arg["type"] in ["number", "Tensor"]:
|
||||
return "int"
|
||||
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||
return _pytype_to_dtype_fn_pytype(arg["type"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"dtype", parameter_decl_builder, ret_decl_builder
|
||||
|
@ -285,13 +289,13 @@ class JitOperator:
|
|||
"""
|
||||
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||
pytype = _pytype_to_decomposition_fn_pytype(arg["type"])
|
||||
default = _get_default_value(arg)
|
||||
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||
return f"{parameter_name}: {pytype}{default}"
|
||||
|
||||
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||
return _pytype_to_decomposition_fn_pytype(arg["type"])
|
||||
|
||||
return self._get_function_signature(
|
||||
"decomposition", parameter_decl_builder, ret_decl_builder
|
||||
|
@ -332,9 +336,9 @@ class JitOperator:
|
|||
p(f"namespace = {self.namespace}")
|
||||
p(f"unqualified_name = {self.unqualified_name}")
|
||||
p(f"overload_name = {self.overload_name}")
|
||||
p(f"is_c10_op = {self.is_c10_op}")
|
||||
p(f"is_vararg = {self.is_vararg}")
|
||||
p(f"is_varret = {self.is_varret}")
|
||||
# p(f"is_c10_op = {self.is_c10_op}")
|
||||
# p(f"is_vararg = {self.is_vararg}")
|
||||
# p(f"is_varret = {self.is_varret}")
|
||||
p(f"is_mutable = {self.is_mutable}")
|
||||
if any(ret["type"] == "Tensor" for ret in self.returns):
|
||||
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
||||
|
|
Loading…
Reference in New Issue