Factor out definition of the _torch_mlir.get_registered_ops function.

It didn't make sense in the main registration file.
pull/309/head
Sean Silva 2021-09-17 04:29:25 +00:00
parent b738db34cd
commit 9e2442d6b0
4 changed files with 143 additions and 108 deletions

View File

@ -14,6 +14,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(TorchMLIRTorchPlugin SHARED add_library(TorchMLIRTorchPlugin SHARED
class_annotator.cpp class_annotator.cpp
get_registered_ops.cpp
function_importer.cpp function_importer.cpp
module_builder.cpp module_builder.cpp
node_importer.cpp node_importer.cpp

View File

@ -0,0 +1,116 @@
//===- get_registered_ops.cpp ---------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "get_registered_ops.h"
namespace py = pybind11;
static const char kGetRegisteredOpsDocstring[] =
R"(Gets a data structure of all registered ops.
The returned data reflects the metadata available in the Torch JIT's
registry at the time of this call. It includes both the operators available
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
uses to implement auxiliary operations that in the non-TorchScript case
are performed by Python itself.
This information is meant for various code generation tools.
Returns:
A list of records, one for each `torch::jit::Operator`. Known to the
Torch JIT operator registry. Each record is a dict of the following:
"name": tuple -> (qualified_name, overload)
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
or is a JIT-only op.
"is_vararg": bool -> Whether the op accepts variable arguments
"is_varret": bool -> Whether the op produces variable returns
"is_mutable": bool -> Whether the op potentially mutates any operand
"arguments" and "returns": List[Dict] -> Having keys:
"type": str -> PyTorch type name as in op signatures
"pytype": str -> PyType style type annotation
"N": (optional) int -> For list types, the arity
"default_debug": (optional) str -> Debug printout of the default value
"alias_info": Dict -> Alias info with keys "before" and "after"
)";
static py::list getRegisteredOps() {
py::list results;
// Walk the JIT operator registry to find all the ops that we might need
// for introspection / ODS generation.
// This registry contains a superset of the ops available to the dispatcher,
// since the JIT has its own dispatch mechanism that it uses to implement
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
// as `aten::__is__`.
for (const std::shared_ptr<torch::jit::Operator> &op :
torch::jit::getAllOperators()) {
const c10::FunctionSchema &schema = op->schema();
py::dict record;
{
py::tuple name(2);
name[0] = schema.name();
name[1] = schema.overload_name();
record["name"] = std::move(name);
}
record["is_c10_op"] = op->isC10Op();
record["is_vararg"] = schema.is_vararg();
record["is_varret"] = schema.is_varret();
record["is_mutable"] = schema.is_mutable();
py::list arguments;
py::list returns;
auto addArgument = [](py::list &container, const c10::Argument &arg) {
py::dict argRecord;
argRecord["name"] = arg.name();
argRecord["type"] = arg.type()->str();
argRecord["pytype"] = arg.type()->annotation_str();
if (arg.N())
argRecord["N"] = *arg.N();
// TODO: If the default value becomes useful, switch on it and return
// a real value, not just a string print.
if (arg.default_value()) {
std::stringstream sout;
sout << *arg.default_value();
argRecord["default_debug"] = sout.str();
}
if (arg.alias_info()) {
py::dict aliasInfo;
py::list before;
py::list after;
for (auto &symbol : arg.alias_info()->beforeSets()) {
before.append(std::string(symbol.toQualString()));
}
for (auto &symbol : arg.alias_info()->afterSets()) {
after.append(std::string(symbol.toQualString()));
}
aliasInfo["is_write"] = arg.alias_info()->isWrite();
aliasInfo["before"] = std::move(before);
aliasInfo["after"] = std::move(after);
argRecord["alias_info"] = std::move(aliasInfo);
}
container.append(std::move(argRecord));
};
for (auto &argument : schema.arguments()) {
addArgument(arguments, argument);
}
for (auto &returnArg : schema.returns()) {
addArgument(returns, returnArg);
}
record["arguments"] = std::move(arguments);
record["returns"] = std::move(returns);
results.append(std::move(record));
}
return results;
}
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) {
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
}

View File

@ -0,0 +1,24 @@
//===- get_registered_ops.h -------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
//
// Listing of the JIT operator registry, for use in generating the `torch`
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H
#define TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H
#include "pybind.h"
namespace torch_mlir {
void initGetRegisteredOpsBindings(py::module &m);
} // namespace torch_mlir
#endif // TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H

View File

@ -10,119 +10,13 @@
#include <ATen/core/dispatch/Dispatcher.h> #include <ATen/core/dispatch/Dispatcher.h>
#include "class_annotator.h" #include "class_annotator.h"
#include "get_registered_ops.h"
#include "module_builder.h" #include "module_builder.h"
using namespace torch_mlir; using namespace torch_mlir;
namespace py = pybind11;
namespace {
static const char kGetRegisteredOpsDocstring[] =
R"(Gets a data structure of all registered ops.
The returned data reflects the metadata available in the Torch JIT's
registry at the time of this call. It includes both the operators available
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
uses to implement auxiliary operations that in the non-TorchScript case
are performed by Python itself.
This information is meant for various code generation tools.
Returns:
A list of records, one for each `torch::jit::Operator`. Known to the
Torch JIT operator registry. Each record is a dict of the following:
"name": tuple -> (qualified_name, overload)
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
or is a JIT-only op.
"is_vararg": bool -> Whether the op accepts variable arguments
"is_varret": bool -> Whether the op produces variable returns
"is_mutable": bool -> Whether the op potentially mutates any operand
"arguments" and "returns": List[Dict] -> Having keys:
"type": str -> PyTorch type name as in op signatures
"pytype": str -> PyType style type annotation
"N": (optional) int -> For list types, the arity
"default_debug": (optional) str -> Debug printout of the default value
"alias_info": Dict -> Alias info with keys "before" and "after"
)";
py::list GetRegisteredOps() {
py::list results;
// Walk the JIT operator registry to find all the ops that we might need
// for introspection / ODS generation.
// This registry contains a superset of the ops available to the dispatcher,
// since the JIT has its own dispatch mechanism that it uses to implement
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
// as `aten::__is__`.
for (const std::shared_ptr<torch::jit::Operator> &op :
torch::jit::getAllOperators()) {
const c10::FunctionSchema &schema = op->schema();
py::dict record;
{
py::tuple name(2);
name[0] = schema.name();
name[1] = schema.overload_name();
record["name"] = std::move(name);
}
record["is_c10_op"] = op->isC10Op();
record["is_vararg"] = schema.is_vararg();
record["is_varret"] = schema.is_varret();
record["is_mutable"] = schema.is_mutable();
py::list arguments;
py::list returns;
auto addArgument = [](py::list &container, const c10::Argument &arg) {
py::dict argRecord;
argRecord["name"] = arg.name();
argRecord["type"] = arg.type()->str();
argRecord["pytype"] = arg.type()->annotation_str();
if (arg.N())
argRecord["N"] = *arg.N();
// TODO: If the default value becomes useful, switch on it and return
// a real value, not just a string print.
if (arg.default_value()) {
std::stringstream sout;
sout << *arg.default_value();
argRecord["default_debug"] = sout.str();
}
if (arg.alias_info()) {
py::dict aliasInfo;
py::list before;
py::list after;
for (auto &symbol : arg.alias_info()->beforeSets()) {
before.append(std::string(symbol.toQualString()));
}
for (auto &symbol : arg.alias_info()->afterSets()) {
after.append(std::string(symbol.toQualString()));
}
aliasInfo["is_write"] = arg.alias_info()->isWrite();
aliasInfo["before"] = std::move(before);
aliasInfo["after"] = std::move(after);
argRecord["alias_info"] = std::move(aliasInfo);
}
container.append(std::move(argRecord));
};
for (auto &argument : schema.arguments()) {
addArgument(arguments, argument);
}
for (auto &returnArg : schema.returns()) {
addArgument(returns, returnArg);
}
record["arguments"] = std::move(arguments);
record["returns"] = std::move(returns);
results.append(std::move(record));
}
return results;
}
} // namespace
PYBIND11_MODULE(_torch_mlir, m) { PYBIND11_MODULE(_torch_mlir, m) {
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
ModuleBuilder::bind(m); ModuleBuilder::bind(m);
initClassAnnotatorBindings(m); initClassAnnotatorBindings(m);
initGetRegisteredOpsBindings(m);
} }