2020-09-25 07:26:29 +08:00
|
|
|
//===- python_bindings.cpp --------------------------------------*- C++ -*-===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-10-02 09:59:58 +08:00
|
|
|
#include "../pybind.h"
|
2020-10-30 08:41:15 +08:00
|
|
|
#include "debug.h"
|
2020-10-02 09:59:58 +08:00
|
|
|
|
2020-09-26 09:13:16 +08:00
|
|
|
#include <ATen/core/dispatch/Dispatcher.h>
|
2020-09-25 07:26:29 +08:00
|
|
|
|
2020-09-26 09:13:16 +08:00
|
|
|
#include "../init_python_bindings.h"
|
|
|
|
#include "acap_dispatch.h"
|
2020-09-29 09:36:00 +08:00
|
|
|
#include "module_builder.h"
|
2021-02-20 08:21:21 +08:00
|
|
|
#include "class_annotator.h"
|
2020-09-25 07:26:29 +08:00
|
|
|
|
2020-09-26 09:13:16 +08:00
|
|
|
using namespace torch_mlir;
|
2020-09-25 07:26:29 +08:00
|
|
|
namespace py = pybind11;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// Python docstring attached to the `get_registered_ops` binding. It documents
// the exact record layout built by GetRegisteredOps(); keep the two in sync
// whenever a key is added or removed there.
static const char kGetRegisteredOpsDocstring[] =
    R"(Gets a data structure of all registered ops.

The returned data reflects the metadata available in the Torch JIT's
registry at the time of this call. It includes both the operators available
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
uses to implement auxiliary operations that in the non-TorchScript case
are performed by Python itself.

This information is meant for various code generation tools.

Returns:
  A list of records, one for each `torch::jit::Operator` known to the
  Torch JIT operator registry. Each record is a dict of the following:
    "name": tuple -> (qualified_name, overload)
    "is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
                         or is a JIT-only op.
    "is_vararg": bool -> Whether the op accepts variable arguments
    "is_varret": bool -> Whether the op produces variable returns
    "is_mutable": bool -> Whether the op potentially mutates any operand
    "arguments" and "returns": List[Dict] -> Having keys:
      "name": str -> Argument name as written in the op signature
      "type": str -> PyTorch type name as in op signatures
      "pytype": str -> PyType style type annotation
      "N": (optional) int -> For list types, the arity
      "default_debug": (optional) str -> Debug printout of the default value
      "alias_info": (optional) Dict -> Alias info with keys:
        "is_write": bool -> Whether the alias annotation is a write
        "before": List[str] -> Qualified names of the "before" alias sets
        "after": List[str] -> Qualified names of the "after" alias sets
)";
|
|
|
|
|
|
|
|
// Converts an argument's alias annotation into a dict with keys "is_write",
// "before" and "after"; the latter two hold the qualified names of the
// annotation's before/after alias sets.
py::dict BuildAliasInfoRecord(const c10::AliasInfo &aliasInfo) {
  py::list before;
  for (const auto &symbol : aliasInfo.beforeSets())
    before.append(std::string(symbol.toQualString()));
  py::list after;
  for (const auto &symbol : aliasInfo.afterSets())
    after.append(std::string(symbol.toQualString()));

  // Insertion order matches the historical layout: is_write, before, after.
  py::dict record;
  record["is_write"] = aliasInfo.isWrite();
  record["before"] = std::move(before);
  record["after"] = std::move(after);
  return record;
}

// Builds the dict describing a single schema argument or return value.
// Optional keys ("N", "default_debug", "alias_info") are only present when
// the schema provides the corresponding information.
py::dict BuildArgumentRecord(const c10::Argument &arg) {
  py::dict record;
  record["name"] = arg.name();
  record["type"] = arg.type()->str();
  record["pytype"] = arg.type()->annotation_str();
  if (arg.N())
    record["N"] = *arg.N();
  // TODO: If the default value becomes useful, switch on it and return
  // a real value, not just a string print.
  if (arg.default_value()) {
    std::ostringstream sout;
    sout << *arg.default_value();
    record["default_debug"] = sout.str();
  }
  if (arg.alias_info())
    record["alias_info"] = BuildAliasInfoRecord(*arg.alias_info());
  return record;
}

// Builds the dict describing one JIT operator, following the layout
// documented in kGetRegisteredOpsDocstring.
py::dict BuildOperatorRecord(const torch::jit::Operator &op) {
  const c10::FunctionSchema &schema = op.schema();

  py::dict record;
  record["name"] = py::make_tuple(schema.name(), schema.overload_name());
  record["is_c10_op"] = op.isC10Op();
  record["is_vararg"] = schema.is_vararg();
  record["is_varret"] = schema.is_varret();
  record["is_mutable"] = schema.is_mutable();

  py::list arguments;
  for (const c10::Argument &argument : schema.arguments())
    arguments.append(BuildArgumentRecord(argument));
  py::list returns;
  for (const c10::Argument &returnArg : schema.returns())
    returns.append(BuildArgumentRecord(returnArg));

  record["arguments"] = std::move(arguments);
  record["returns"] = std::move(returns);
  return record;
}

// Returns one record per operator in the Torch JIT registry (see
// kGetRegisteredOpsDocstring for the record layout).
py::list GetRegisteredOps() {
  py::list results;

  // Walk the JIT operator registry to find all the ops that we might need
  // for introspection / ODS generation.
  // This registry contains a superset of the ops available to the dispatcher,
  // since the JIT has its own dispatch mechanism that it uses to implement
  // "prim" ops and a handful of "aten" ops that are effectively prim ops, such
  // as `aten::__is__`.
  for (const std::shared_ptr<torch::jit::Operator> &op :
       torch::jit::getAllOperators())
    results.append(BuildOperatorRecord(*op));

  return results;
}
|
|
|
|
|
2020-11-13 14:27:05 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Registers the builder-related Python bindings on module `m`. Registration
// order is preserved deliberately: free functions first, then the
// AcapController class, the op-registry query, and finally the bindings
// provided by module_builder.h and class_annotator.h.
void torch_mlir::InitBuilderBindings(py::module &m) {
  // Debug tracing toggle (declared in debug.h).
  m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);

  // AcapController exposes __enter__/__exit__, so Python can use it in a
  // `with` statement; instances are held via std::shared_ptr.
  using AcapControllerBinding =
      py::class_<AcapController, std::shared_ptr<AcapController>>;
  AcapControllerBinding acapBinding(m, "AcapController");
  acapBinding.def("__enter__", &AcapController::contextEnter);
  acapBinding.def("__exit__", &AcapController::contextExit);
  acapBinding.def("returns", &AcapController::returns);

  // Op registry introspection for code generation tools.
  m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);

  // Remaining bindings are provided by their own headers.
  ModuleBuilder::bind(m);
  initClassAnnotatorBindings(m);
}
|