diff --git a/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt b/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt index c3a18951f..13c11b386 100644 --- a/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt +++ b/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt @@ -14,6 +14,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(TorchMLIRTorchPlugin SHARED class_annotator.cpp + get_registered_ops.cpp function_importer.cpp module_builder.cpp node_importer.cpp diff --git a/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.cpp b/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.cpp new file mode 100644 index 000000000..96018dc59 --- /dev/null +++ b/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.cpp @@ -0,0 +1,116 @@ +//===- get_registered_ops.cpp ---------------------------------------------===// +// +// This file is licensed under a pytorch-style license +// See LICENSE for license information. +// +//===----------------------------------------------------------------------===// + +#include "get_registered_ops.h" + +namespace py = pybind11; + +static const char kGetRegisteredOpsDocstring[] = + R"(Gets a data structure of all registered ops. + +The returned data reflects the metadata available in the Torch JIT's +registry at the time of this call. It includes both the operators available +in the c10 dispatcher and an auxiliary set of operators that the Torch JIT +uses to implement auxiliary operations that in the non-TorchScript case +are performed by Python itself. + +This information is meant for various code generation tools. + +Returns: + A list of records, one for each `torch::jit::Operator`. Known to the + Torch JIT operator registry. Each record is a dict of the following: + "name": tuple -> (qualified_name, overload) + "is_c10_op": bool -> Whether the op is in the c10 dispatcher registry, + or is a JIT-only op. + "is_vararg": bool -> Whether the op accepts variable arguments + "is_varret": bool -> Whether the op produces variable returns + "is_mutable": bool -> Whether the op potentially mutates any operand + "arguments" and "returns": List[Dict] -> Having keys: + "type": str -> PyTorch type name as in op signatures + "pytype": str -> PyType style type annotation + "N": (optional) int -> For list types, the arity + "default_debug": (optional) str -> Debug printout of the default value + "alias_info": Dict -> Alias info with keys "before" and "after" +)"; + +static py::list getRegisteredOps() { + py::list results; + + // Walk the JIT operator registry to find all the ops that we might need + // for introspection / ODS generation. + // This registry contains a superset of the ops available to the dispatcher, + // since the JIT has its own dispatch mechanism that it uses to implement + // "prim" ops and a handful of "aten" ops that are effectively prim ops, such + // as `aten::__is__`. + for (const std::shared_ptr &op : + torch::jit::getAllOperators()) { + const c10::FunctionSchema &schema = op->schema(); + + py::dict record; + { + py::tuple name(2); + name[0] = schema.name(); + name[1] = schema.overload_name(); + record["name"] = std::move(name); + } + + record["is_c10_op"] = op->isC10Op(); + record["is_vararg"] = schema.is_vararg(); + record["is_varret"] = schema.is_varret(); + record["is_mutable"] = schema.is_mutable(); + + py::list arguments; + py::list returns; + auto addArgument = [](py::list &container, const c10::Argument &arg) { + py::dict argRecord; + argRecord["name"] = arg.name(); + argRecord["type"] = arg.type()->str(); + argRecord["pytype"] = arg.type()->annotation_str(); + if (arg.N()) + argRecord["N"] = *arg.N(); + // TODO: If the default value becomes useful, switch on it and return + // a real value, not just a string print. + if (arg.default_value()) { + std::stringstream sout; + sout << *arg.default_value(); + argRecord["default_debug"] = sout.str(); + } + if (arg.alias_info()) { + py::dict aliasInfo; + py::list before; + py::list after; + for (auto &symbol : arg.alias_info()->beforeSets()) { + before.append(std::string(symbol.toQualString())); + } + for (auto &symbol : arg.alias_info()->afterSets()) { + after.append(std::string(symbol.toQualString())); + } + aliasInfo["is_write"] = arg.alias_info()->isWrite(); + aliasInfo["before"] = std::move(before); + aliasInfo["after"] = std::move(after); + argRecord["alias_info"] = std::move(aliasInfo); + } + + container.append(std::move(argRecord)); + }; + for (auto &argument : schema.arguments()) { + addArgument(arguments, argument); + } + for (auto &returnArg : schema.returns()) { + addArgument(returns, returnArg); + } + record["arguments"] = std::move(arguments); + record["returns"] = std::move(returns); + results.append(std::move(record)); + } + + return results; +} + +void torch_mlir::initGetRegisteredOpsBindings(py::module &m) { + m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring); +} diff --git a/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.h b/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.h new file mode 100644 index 000000000..b4cbd639a --- /dev/null +++ b/external/torch-mlir/TorchPlugin/csrc/get_registered_ops.h @@ -0,0 +1,24 @@ +//===- get_registered_ops.h -------------------------------------*- C++ -*-===// +// +// This file is licensed under a pytorch-style license +// See LICENSE for license information. +// +//===----------------------------------------------------------------------===// +// +// Listing of the JIT operator registry, for use in generating the `torch` +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H +#define TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H + +#include "pybind.h" + +namespace torch_mlir { + +void initGetRegisteredOpsBindings(py::module &m); + +} // namespace torch_mlir + +#endif // TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H diff --git a/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp b/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp index 0a091ebb7..c453f2173 100644 --- a/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp +++ b/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp @@ -10,119 +10,13 @@ #include #include "class_annotator.h" +#include "get_registered_ops.h" #include "module_builder.h" using namespace torch_mlir; -namespace py = pybind11; - -namespace { - -static const char kGetRegisteredOpsDocstring[] = - R"(Gets a data structure of all registered ops. - -The returned data reflects the metadata available in the Torch JIT's -registry at the time of this call. It includes both the operators available -in the c10 dispatcher and an auxiliary set of operators that the Torch JIT -uses to implement auxiliary operations that in the non-TorchScript case -are performed by Python itself. - -This information is meant for various code generation tools. - -Returns: - A list of records, one for each `torch::jit::Operator`. Known to the - Torch JIT operator registry. Each record is a dict of the following: - "name": tuple -> (qualified_name, overload) - "is_c10_op": bool -> Whether the op is in the c10 dispatcher registry, - or is a JIT-only op. - "is_vararg": bool -> Whether the op accepts variable arguments - "is_varret": bool -> Whether the op produces variable returns - "is_mutable": bool -> Whether the op potentially mutates any operand - "arguments" and "returns": List[Dict] -> Having keys: - "type": str -> PyTorch type name as in op signatures - "pytype": str -> PyType style type annotation - "N": (optional) int -> For list types, the arity - "default_debug": (optional) str -> Debug printout of the default value - "alias_info": Dict -> Alias info with keys "before" and "after" -)"; - -py::list GetRegisteredOps() { - py::list results; - - // Walk the JIT operator registry to find all the ops that we might need - // for introspection / ODS generation. - // This registry contains a superset of the ops available to the dispatcher, - // since the JIT has its own dispatch mechanism that it uses to implement - // "prim" ops and a handful of "aten" ops that are effectively prim ops, such - // as `aten::__is__`. - for (const std::shared_ptr &op : - torch::jit::getAllOperators()) { - const c10::FunctionSchema &schema = op->schema(); - - py::dict record; - { - py::tuple name(2); - name[0] = schema.name(); - name[1] = schema.overload_name(); - record["name"] = std::move(name); - } - - record["is_c10_op"] = op->isC10Op(); - record["is_vararg"] = schema.is_vararg(); - record["is_varret"] = schema.is_varret(); - record["is_mutable"] = schema.is_mutable(); - - py::list arguments; - py::list returns; - auto addArgument = [](py::list &container, const c10::Argument &arg) { - py::dict argRecord; - argRecord["name"] = arg.name(); - argRecord["type"] = arg.type()->str(); - argRecord["pytype"] = arg.type()->annotation_str(); - if (arg.N()) - argRecord["N"] = *arg.N(); - // TODO: If the default value becomes useful, switch on it and return - // a real value, not just a string print. - if (arg.default_value()) { - std::stringstream sout; - sout << *arg.default_value(); - argRecord["default_debug"] = sout.str(); - } - if (arg.alias_info()) { - py::dict aliasInfo; - py::list before; - py::list after; - for (auto &symbol : arg.alias_info()->beforeSets()) { - before.append(std::string(symbol.toQualString())); - } - for (auto &symbol : arg.alias_info()->afterSets()) { - after.append(std::string(symbol.toQualString())); - } - aliasInfo["is_write"] = arg.alias_info()->isWrite(); - aliasInfo["before"] = std::move(before); - aliasInfo["after"] = std::move(after); - argRecord["alias_info"] = std::move(aliasInfo); - } - - container.append(std::move(argRecord)); - }; - for (auto &argument : schema.arguments()) { - addArgument(arguments, argument); - } - for (auto &returnArg : schema.returns()) { - addArgument(returns, returnArg); - } - record["arguments"] = std::move(arguments); - record["returns"] = std::move(returns); - results.append(std::move(record)); - } - - return results; -} - -} // namespace PYBIND11_MODULE(_torch_mlir, m) { - m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring); ModuleBuilder::bind(m); initClassAnnotatorBindings(m); + initGetRegisteredOpsBindings(m); }