mirror of https://github.com/llvm/torch-mlir
Factor out definition of the _torch_mlir.get_registered_ops function.
It didn't make sense in the main registration file.pull/309/head
parent
b738db34cd
commit
9e2442d6b0
|
@ -14,6 +14,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
add_library(TorchMLIRTorchPlugin SHARED
|
add_library(TorchMLIRTorchPlugin SHARED
|
||||||
class_annotator.cpp
|
class_annotator.cpp
|
||||||
|
get_registered_ops.cpp
|
||||||
function_importer.cpp
|
function_importer.cpp
|
||||||
module_builder.cpp
|
module_builder.cpp
|
||||||
node_importer.cpp
|
node_importer.cpp
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
//===- get_registered_ops.cpp ---------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "get_registered_ops.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
static const char kGetRegisteredOpsDocstring[] =
|
||||||
|
R"(Gets a data structure of all registered ops.
|
||||||
|
|
||||||
|
The returned data reflects the metadata available in the Torch JIT's
|
||||||
|
registry at the time of this call. It includes both the operators available
|
||||||
|
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
|
||||||
|
uses to implement auxiliary operations that in the non-TorchScript case
|
||||||
|
are performed by Python itself.
|
||||||
|
|
||||||
|
This information is meant for various code generation tools.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of records, one for each `torch::jit::Operator`. Known to the
|
||||||
|
Torch JIT operator registry. Each record is a dict of the following:
|
||||||
|
"name": tuple -> (qualified_name, overload)
|
||||||
|
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
|
||||||
|
or is a JIT-only op.
|
||||||
|
"is_vararg": bool -> Whether the op accepts variable arguments
|
||||||
|
"is_varret": bool -> Whether the op produces variable returns
|
||||||
|
"is_mutable": bool -> Whether the op potentially mutates any operand
|
||||||
|
"arguments" and "returns": List[Dict] -> Having keys:
|
||||||
|
"type": str -> PyTorch type name as in op signatures
|
||||||
|
"pytype": str -> PyType style type annotation
|
||||||
|
"N": (optional) int -> For list types, the arity
|
||||||
|
"default_debug": (optional) str -> Debug printout of the default value
|
||||||
|
"alias_info": Dict -> Alias info with keys "before" and "after"
|
||||||
|
)";
|
||||||
|
|
||||||
|
static py::list getRegisteredOps() {
|
||||||
|
py::list results;
|
||||||
|
|
||||||
|
// Walk the JIT operator registry to find all the ops that we might need
|
||||||
|
// for introspection / ODS generation.
|
||||||
|
// This registry contains a superset of the ops available to the dispatcher,
|
||||||
|
// since the JIT has its own dispatch mechanism that it uses to implement
|
||||||
|
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||||
|
// as `aten::__is__`.
|
||||||
|
for (const std::shared_ptr<torch::jit::Operator> &op :
|
||||||
|
torch::jit::getAllOperators()) {
|
||||||
|
const c10::FunctionSchema &schema = op->schema();
|
||||||
|
|
||||||
|
py::dict record;
|
||||||
|
{
|
||||||
|
py::tuple name(2);
|
||||||
|
name[0] = schema.name();
|
||||||
|
name[1] = schema.overload_name();
|
||||||
|
record["name"] = std::move(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
record["is_c10_op"] = op->isC10Op();
|
||||||
|
record["is_vararg"] = schema.is_vararg();
|
||||||
|
record["is_varret"] = schema.is_varret();
|
||||||
|
record["is_mutable"] = schema.is_mutable();
|
||||||
|
|
||||||
|
py::list arguments;
|
||||||
|
py::list returns;
|
||||||
|
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
||||||
|
py::dict argRecord;
|
||||||
|
argRecord["name"] = arg.name();
|
||||||
|
argRecord["type"] = arg.type()->str();
|
||||||
|
argRecord["pytype"] = arg.type()->annotation_str();
|
||||||
|
if (arg.N())
|
||||||
|
argRecord["N"] = *arg.N();
|
||||||
|
// TODO: If the default value becomes useful, switch on it and return
|
||||||
|
// a real value, not just a string print.
|
||||||
|
if (arg.default_value()) {
|
||||||
|
std::stringstream sout;
|
||||||
|
sout << *arg.default_value();
|
||||||
|
argRecord["default_debug"] = sout.str();
|
||||||
|
}
|
||||||
|
if (arg.alias_info()) {
|
||||||
|
py::dict aliasInfo;
|
||||||
|
py::list before;
|
||||||
|
py::list after;
|
||||||
|
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
||||||
|
before.append(std::string(symbol.toQualString()));
|
||||||
|
}
|
||||||
|
for (auto &symbol : arg.alias_info()->afterSets()) {
|
||||||
|
after.append(std::string(symbol.toQualString()));
|
||||||
|
}
|
||||||
|
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||||
|
aliasInfo["before"] = std::move(before);
|
||||||
|
aliasInfo["after"] = std::move(after);
|
||||||
|
argRecord["alias_info"] = std::move(aliasInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
container.append(std::move(argRecord));
|
||||||
|
};
|
||||||
|
for (auto &argument : schema.arguments()) {
|
||||||
|
addArgument(arguments, argument);
|
||||||
|
}
|
||||||
|
for (auto &returnArg : schema.returns()) {
|
||||||
|
addArgument(returns, returnArg);
|
||||||
|
}
|
||||||
|
record["arguments"] = std::move(arguments);
|
||||||
|
record["returns"] = std::move(returns);
|
||||||
|
results.append(std::move(record));
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) {
|
||||||
|
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
//===- get_registered_ops.h -------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Listing of the JIT operator registry, for use in generating the `torch`
|
||||||
|
// dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H
|
||||||
|
#define TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H
|
||||||
|
|
||||||
|
#include "pybind.h"
|
||||||
|
|
||||||
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
void initGetRegisteredOpsBindings(py::module &m);
|
||||||
|
|
||||||
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIRPLUGIN_CSRC_GETREGISTEREDOPS_H
|
|
@ -10,119 +10,13 @@
|
||||||
#include <ATen/core/dispatch/Dispatcher.h>
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "class_annotator.h"
|
||||||
|
#include "get_registered_ops.h"
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
static const char kGetRegisteredOpsDocstring[] =
|
|
||||||
R"(Gets a data structure of all registered ops.
|
|
||||||
|
|
||||||
The returned data reflects the metadata available in the Torch JIT's
|
|
||||||
registry at the time of this call. It includes both the operators available
|
|
||||||
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
|
|
||||||
uses to implement auxiliary operations that in the non-TorchScript case
|
|
||||||
are performed by Python itself.
|
|
||||||
|
|
||||||
This information is meant for various code generation tools.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of records, one for each `torch::jit::Operator`. Known to the
|
|
||||||
Torch JIT operator registry. Each record is a dict of the following:
|
|
||||||
"name": tuple -> (qualified_name, overload)
|
|
||||||
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
|
|
||||||
or is a JIT-only op.
|
|
||||||
"is_vararg": bool -> Whether the op accepts variable arguments
|
|
||||||
"is_varret": bool -> Whether the op produces variable returns
|
|
||||||
"is_mutable": bool -> Whether the op potentially mutates any operand
|
|
||||||
"arguments" and "returns": List[Dict] -> Having keys:
|
|
||||||
"type": str -> PyTorch type name as in op signatures
|
|
||||||
"pytype": str -> PyType style type annotation
|
|
||||||
"N": (optional) int -> For list types, the arity
|
|
||||||
"default_debug": (optional) str -> Debug printout of the default value
|
|
||||||
"alias_info": Dict -> Alias info with keys "before" and "after"
|
|
||||||
)";
|
|
||||||
|
|
||||||
py::list GetRegisteredOps() {
|
|
||||||
py::list results;
|
|
||||||
|
|
||||||
// Walk the JIT operator registry to find all the ops that we might need
|
|
||||||
// for introspection / ODS generation.
|
|
||||||
// This registry contains a superset of the ops available to the dispatcher,
|
|
||||||
// since the JIT has its own dispatch mechanism that it uses to implement
|
|
||||||
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
|
||||||
// as `aten::__is__`.
|
|
||||||
for (const std::shared_ptr<torch::jit::Operator> &op :
|
|
||||||
torch::jit::getAllOperators()) {
|
|
||||||
const c10::FunctionSchema &schema = op->schema();
|
|
||||||
|
|
||||||
py::dict record;
|
|
||||||
{
|
|
||||||
py::tuple name(2);
|
|
||||||
name[0] = schema.name();
|
|
||||||
name[1] = schema.overload_name();
|
|
||||||
record["name"] = std::move(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
record["is_c10_op"] = op->isC10Op();
|
|
||||||
record["is_vararg"] = schema.is_vararg();
|
|
||||||
record["is_varret"] = schema.is_varret();
|
|
||||||
record["is_mutable"] = schema.is_mutable();
|
|
||||||
|
|
||||||
py::list arguments;
|
|
||||||
py::list returns;
|
|
||||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
|
||||||
py::dict argRecord;
|
|
||||||
argRecord["name"] = arg.name();
|
|
||||||
argRecord["type"] = arg.type()->str();
|
|
||||||
argRecord["pytype"] = arg.type()->annotation_str();
|
|
||||||
if (arg.N())
|
|
||||||
argRecord["N"] = *arg.N();
|
|
||||||
// TODO: If the default value becomes useful, switch on it and return
|
|
||||||
// a real value, not just a string print.
|
|
||||||
if (arg.default_value()) {
|
|
||||||
std::stringstream sout;
|
|
||||||
sout << *arg.default_value();
|
|
||||||
argRecord["default_debug"] = sout.str();
|
|
||||||
}
|
|
||||||
if (arg.alias_info()) {
|
|
||||||
py::dict aliasInfo;
|
|
||||||
py::list before;
|
|
||||||
py::list after;
|
|
||||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
|
||||||
before.append(std::string(symbol.toQualString()));
|
|
||||||
}
|
|
||||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
|
||||||
after.append(std::string(symbol.toQualString()));
|
|
||||||
}
|
|
||||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
|
||||||
aliasInfo["before"] = std::move(before);
|
|
||||||
aliasInfo["after"] = std::move(after);
|
|
||||||
argRecord["alias_info"] = std::move(aliasInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
container.append(std::move(argRecord));
|
|
||||||
};
|
|
||||||
for (auto &argument : schema.arguments()) {
|
|
||||||
addArgument(arguments, argument);
|
|
||||||
}
|
|
||||||
for (auto &returnArg : schema.returns()) {
|
|
||||||
addArgument(returns, returnArg);
|
|
||||||
}
|
|
||||||
record["arguments"] = std::move(arguments);
|
|
||||||
record["returns"] = std::move(returns);
|
|
||||||
results.append(std::move(record));
|
|
||||||
}
|
|
||||||
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
PYBIND11_MODULE(_torch_mlir, m) {
|
PYBIND11_MODULE(_torch_mlir, m) {
|
||||||
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
|
|
||||||
ModuleBuilder::bind(m);
|
ModuleBuilder::bind(m);
|
||||||
initClassAnnotatorBindings(m);
|
initClassAnnotatorBindings(m);
|
||||||
|
initGetRegisteredOpsBindings(m);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue