From bda2214d49e4fc01ca33675884de9291a412ed3e Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Thu, 10 Oct 2024 08:53:46 -0700 Subject: [PATCH] Rewrite get_registered_ops.cpp to python version --- .../build_tools/get_registered_ops.py | 60 +++++++++++++++++++ .../jit_ir_importer/build_tools/registry.py | 46 +++++++------- 2 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/get_registered_ops.py diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/get_registered_ops.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/get_registered_ops.py new file mode 100644 index 000000000..5ff3efdb0 --- /dev/null +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/get_registered_ops.py @@ -0,0 +1,60 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +"""Listing of the JIT operator registry, for use in generating the `torch` dialect. +""" + + +import torch +import torch._C +import pybind11 + +def get_registered_ops(): + results = [] + + # Walk the JIT operator registry to find all the ops that we might need + # for introspection / ODS generation. + # This registry contains a superset of the ops available to the dispatcher, + # since the JIT has its own dispatch mechanism that it uses to implement + # "prim" ops and a handful of "aten" ops that are effectively prim ops, such + # as `aten::__is__`. + for schema in torch._C._jit_get_all_schemas(): + record = {} + + record["name"] = schema.name + record["overload_name"] = schema.overload_name + record["is_mutable"] = schema.is_mutable + + arguments = [] + returns = [] + + def add_argument(container, arg): + arg_record = { + "name": arg.name, + "type": arg.type.annotation_str, + "kwarg_only" : arg.kwarg_only, + "is_out": arg.is_out, + } + if arg.default_value: + arg_record["default_value"] = arg.default_value + if arg.alias_info: + alias_info = { + "is_write": arg.alias_info.is_write, + "before_set": [str(symbol) for symbol in arg.alias_info.before_set], + "after_set": [str(symbol) for symbol in arg.alias_info.after_set], + } + arg_record["alias_info"] = alias_info + + container.append(arg_record) + + for argument in schema.arguments: + add_argument(arguments, argument) + for return_arg in schema.returns: + add_argument(returns, return_arg) + + record["arguments"] = arguments + record["returns"] = returns + results.append(record) + + return results \ No newline at end of file diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py index ec8317270..2054945f7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py @@ -14,9 +14,10 @@ import difflib from .utils import TextEmitter # Note that this utility exists only in the c-extension. -from torch_mlir._mlir_libs._jit_ir_importer import ( - get_registered_ops, -) # pytype: disable=import-error +# from torch_mlir._mlir_libs._jit_ir_importer import ( +# get_registered_ops, +# ) # pytype: disable=import-error +from .get_registered_ops import get_registered_ops def _rename_python_keyword_parameter_name(parameter_name: str) -> str: @@ -28,7 +29,7 @@ def _rename_python_keyword_parameter_name(parameter_name: str) -> str: def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: default = "" if "default_debug" in arg: - if "List" in arg["pytype"]: + if "List" in arg["type"]: # TorchScript doesn't allow lists as default parameters due # to the weird Python semantics of mutable default # arguments. So munge them into tuples, which work @@ -44,7 +45,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: default_debug = "()" else: default_debug = default_list.replace("[", "(").replace("]", ",)") - elif arg["pytype"] == "str": + elif arg["type"] == "str": default_debug = repr(arg["default_debug"]).replace("'", '"') else: default_debug = arg["default_debug"] @@ -119,9 +120,9 @@ class JitOperator: self.namespace = namespace self.unqualified_name = unqualified_name self.overload_name = op_info["name"][1] - self.is_c10_op = op_info["is_c10_op"] - self.is_vararg = op_info["is_vararg"] - self.is_varret = op_info["is_varret"] + # self.is_c10_op = op_info["is_c10_op"] + # self.is_vararg = op_info["is_vararg"] + # self.is_varret = op_info["is_varret"] self.is_mutable = op_info["is_mutable"] self.arguments = op_info["arguments"] self.returns = op_info["returns"] @@ -150,11 +151,14 @@ class JitOperator: concern for them) """ overload = "" if not self.overload_name else f".{self.overload_name}" - if self.is_vararg: + # Check if any argument is a variable argument + if any(arg["type"] == "..." for arg in self.arguments): arg_str = "..." else: arg_str = ", ".join(arg["type"] for arg in self.arguments) - if self.is_varret: + + # Check if any return type is a variable return type + if any(ret["type"] == "..." for ret in self.returns): ret_str = "..." else: ret_str = ", ".join(ret["type"] for ret in self.returns) @@ -233,13 +237,13 @@ class JitOperator: """ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - pytype = _pytype_to_shape_fn_pytype(arg["pytype"]) + pytype = _pytype_to_shape_fn_pytype(arg["type"]) default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - return _pytype_to_shape_fn_pytype(arg["pytype"]) + return _pytype_to_shape_fn_pytype(arg["type"]) return self._get_function_signature( "shape", parameter_decl_builder, ret_decl_builder @@ -255,10 +259,10 @@ class JitOperator: """ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - pytype = _pytype_to_dtype_fn_pytype(arg["pytype"]) + pytype = _pytype_to_dtype_fn_pytype(arg["type"]) default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) - if "Tensor" in arg["pytype"]: + if "Tensor" in arg["type"]: return f"{parameter_name}_rank_dtype: {pytype}{default}" return f"{parameter_name}: {pytype}{default}" @@ -267,9 +271,9 @@ class JitOperator: # results of type `number`. Here we handle this case because # `_pytype_to_dtype_fn_pytype` will replace `number` with # `Union[int, float]`. - if arg["pytype"] in ["number", "Tensor"]: + if arg["type"] in ["number", "Tensor"]: return "int" - return _pytype_to_dtype_fn_pytype(arg["pytype"]) + return _pytype_to_dtype_fn_pytype(arg["type"]) return self._get_function_signature( "dtype", parameter_decl_builder, ret_decl_builder @@ -285,13 +289,13 @@ class JitOperator: """ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"]) + pytype = _pytype_to_decomposition_fn_pytype(arg["type"]) default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: - return _pytype_to_decomposition_fn_pytype(arg["pytype"]) + return _pytype_to_decomposition_fn_pytype(arg["type"]) return self._get_function_signature( "decomposition", parameter_decl_builder, ret_decl_builder @@ -332,9 +336,9 @@ class JitOperator: p(f"namespace = {self.namespace}") p(f"unqualified_name = {self.unqualified_name}") p(f"overload_name = {self.overload_name}") - p(f"is_c10_op = {self.is_c10_op}") - p(f"is_vararg = {self.is_vararg}") - p(f"is_varret = {self.is_varret}") + # p(f"is_c10_op = {self.is_c10_op}") + # p(f"is_vararg = {self.is_vararg}") + # p(f"is_varret = {self.is_varret}") p(f"is_mutable = {self.is_mutable}") if any(ret["type"] == "Tensor" for ret in self.returns): p(f"shape_function_signature = {self.get_shape_function_signature()}")