Rewrite get_registered_ops.cpp to python version

pull/3780/head
AmosLewis 2024-10-10 08:53:46 -07:00
parent b08d08682f
commit bda2214d49
2 changed files with 85 additions and 21 deletions

View File

@ -0,0 +1,60 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""Listing of the JIT operator registry, for use in generating the `torch` dialect.
"""
import torch
import torch._C
import pybind11
def get_registered_ops():
results = []
# Walk the JIT operator registry to find all the ops that we might need
# for introspection / ODS generation.
# This registry contains a superset of the ops available to the dispatcher,
# since the JIT has its own dispatch mechanism that it uses to implement
# "prim" ops and a handful of "aten" ops that are effectively prim ops, such
# as `aten::__is__`.
for schema in torch._C._jit_get_all_schemas():
record = {}
record["name"] = schema.name
record["overload_name"] = schema.overload_name
record["is_mutable"] = schema.is_mutable
arguments = []
returns = []
def add_argument(container, arg):
arg_record = {
"name": arg.name,
"type": arg.type.annotation_str,
"kwarg_only" : arg.kwarg_only,
"is_out": arg.is_out,
}
if arg.default_value:
arg_record["default_value"] = arg.default_value
if arg.alias_info:
alias_info = {
"is_write": arg.alias_info.is_write,
"before_set": [str(symbol) for symbol in arg.alias_info.before_set],
"after_set": [str(symbol) for symbol in arg.alias_info.after_set],
}
arg_record["alias_info"] = alias_info
container.append(arg_record)
for argument in schema.arguments:
add_argument(arguments, argument)
for return_arg in schema.returns:
add_argument(returns, return_arg)
record["arguments"] = arguments
record["returns"] = returns
results.append(record)
return results

View File

@ -14,9 +14,10 @@ import difflib
from .utils import TextEmitter
# Note that this utility exists only in the c-extension.
from torch_mlir._mlir_libs._jit_ir_importer import (
get_registered_ops,
) # pytype: disable=import-error
# from torch_mlir._mlir_libs._jit_ir_importer import (
# get_registered_ops,
# ) # pytype: disable=import-error
from .get_registered_ops import get_registered_ops
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
@ -28,7 +29,7 @@ def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
default = ""
if "default_debug" in arg:
if "List" in arg["pytype"]:
if "List" in arg["type"]:
# TorchScript doesn't allow lists as default parameters due
# to the weird Python semantics of mutable default
# arguments. So munge them into tuples, which work
@ -44,7 +45,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
default_debug = "()"
else:
default_debug = default_list.replace("[", "(").replace("]", ",)")
elif arg["pytype"] == "str":
elif arg["type"] == "str":
default_debug = repr(arg["default_debug"]).replace("'", '"')
else:
default_debug = arg["default_debug"]
@ -119,9 +120,9 @@ class JitOperator:
self.namespace = namespace
self.unqualified_name = unqualified_name
self.overload_name = op_info["name"][1]
self.is_c10_op = op_info["is_c10_op"]
self.is_vararg = op_info["is_vararg"]
self.is_varret = op_info["is_varret"]
# self.is_c10_op = op_info["is_c10_op"]
# self.is_vararg = op_info["is_vararg"]
# self.is_varret = op_info["is_varret"]
self.is_mutable = op_info["is_mutable"]
self.arguments = op_info["arguments"]
self.returns = op_info["returns"]
@ -150,11 +151,14 @@ class JitOperator:
concern for them)
"""
overload = "" if not self.overload_name else f".{self.overload_name}"
if self.is_vararg:
# Check if any argument is a variable argument
if any(arg["type"] == "..." for arg in self.arguments):
arg_str = "..."
else:
arg_str = ", ".join(arg["type"] for arg in self.arguments)
if self.is_varret:
# Check if any return type is a variable return type
if any(ret["type"] == "..." for ret in self.returns):
ret_str = "..."
else:
ret_str = ", ".join(ret["type"] for ret in self.returns)
@ -233,13 +237,13 @@ class JitOperator:
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
pytype = _pytype_to_shape_fn_pytype(arg["type"])
default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
return f"{parameter_name}: {pytype}{default}"
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return _pytype_to_shape_fn_pytype(arg["pytype"])
return _pytype_to_shape_fn_pytype(arg["type"])
return self._get_function_signature(
"shape", parameter_decl_builder, ret_decl_builder
@ -255,10 +259,10 @@ class JitOperator:
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
pytype = _pytype_to_dtype_fn_pytype(arg["type"])
default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
if "Tensor" in arg["pytype"]:
if "Tensor" in arg["type"]:
return f"{parameter_name}_rank_dtype: {pytype}{default}"
return f"{parameter_name}: {pytype}{default}"
@ -267,9 +271,9 @@ class JitOperator:
# results of type `number`. Here we handle this case because
# `_pytype_to_dtype_fn_pytype` will replace `number` with
# `Union[int, float]`.
if arg["pytype"] in ["number", "Tensor"]:
if arg["type"] in ["number", "Tensor"]:
return "int"
return _pytype_to_dtype_fn_pytype(arg["pytype"])
return _pytype_to_dtype_fn_pytype(arg["type"])
return self._get_function_signature(
"dtype", parameter_decl_builder, ret_decl_builder
@ -285,13 +289,13 @@ class JitOperator:
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
pytype = _pytype_to_decomposition_fn_pytype(arg["type"])
default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
return f"{parameter_name}: {pytype}{default}"
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
return _pytype_to_decomposition_fn_pytype(arg["type"])
return self._get_function_signature(
"decomposition", parameter_decl_builder, ret_decl_builder
@ -332,9 +336,9 @@ class JitOperator:
p(f"namespace = {self.namespace}")
p(f"unqualified_name = {self.unqualified_name}")
p(f"overload_name = {self.overload_name}")
p(f"is_c10_op = {self.is_c10_op}")
p(f"is_vararg = {self.is_vararg}")
p(f"is_varret = {self.is_varret}")
# p(f"is_c10_op = {self.is_c10_op}")
# p(f"is_vararg = {self.is_vararg}")
# p(f"is_varret = {self.is_varret}")
p(f"is_mutable = {self.is_mutable}")
if any(ret["type"] == "Tensor" for ret in self.returns):
p(f"shape_function_signature = {self.get_shape_function_signature()}")