mirror of https://github.com/llvm/torch-mlir
Remove pybind deps from importer and annotator (#903)
* Remove pybind deps from importer and annotator * Rename files to class_annotator_pybind.cpp/.hpull/918/head snapshot-20220608.497
parent
e1b38e74dd
commit
bd53998da8
|
@ -12,6 +12,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
|||
|
||||
add_library(TorchMLIRJITIRImporter MODULE
|
||||
class_annotator.cpp
|
||||
class_annotator_pybind.cpp
|
||||
get_registered_ops.cpp
|
||||
function_importer.cpp
|
||||
module_builder.cpp
|
||||
|
|
|
@ -11,8 +11,6 @@
|
|||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "torch/csrc/Dtype.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -150,21 +148,10 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
|||
return *it->second;
|
||||
}
|
||||
|
||||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
||||
return dtype->scalar_type;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "unsupported scalar type '" << obj << "'";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
||||
py::list pyArgAnnotations,
|
||||
std::vector<ArgAnnotation> argAnnotations,
|
||||
torch::jit::Function *function) {
|
||||
if (pyArgAnnotations.size() != function->num_inputs()) {
|
||||
if (argAnnotations.size() != function->num_inputs()) {
|
||||
throw std::invalid_argument("Arg annotations should have one entry per "
|
||||
"function parameter (including self).");
|
||||
}
|
||||
|
@ -172,29 +159,13 @@ static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
|||
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
||||
ArgAnnotation{});
|
||||
}
|
||||
std::vector<ArgAnnotation> &argAnnotations =
|
||||
methodAnnotation.argAnnotations.value();
|
||||
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
|
||||
if (pyArgAnnotations[i].is_none()) {
|
||||
continue;
|
||||
}
|
||||
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
|
||||
auto shape = tuple[0];
|
||||
auto dtype = tuple[1];
|
||||
auto hasValueSemantics = tuple[2];
|
||||
if (!shape.is_none()) {
|
||||
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
|
||||
}
|
||||
if (!dtype.is_none()) {
|
||||
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
|
||||
}
|
||||
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
|
||||
};
|
||||
|
||||
methodAnnotation.argAnnotations = argAnnotations;
|
||||
}
|
||||
|
||||
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path,
|
||||
py::list argAnnotations) {
|
||||
std::vector<ArgAnnotation> argAnnotations) {
|
||||
if (path.size() == 0) {
|
||||
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
||||
"shapes/dtypes of a method of a class.");
|
||||
|
@ -331,12 +302,3 @@ std::string ClassAnnotator::toString() {
|
|||
ss << "}\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||
.def(py::init<>())
|
||||
.def("exportPath", &ClassAnnotator::exportPath)
|
||||
.def("exportNone", &ClassAnnotator::exportNone)
|
||||
.def("annotateArgs", &ClassAnnotator::annotateArgs)
|
||||
.def("__repr__", &ClassAnnotator::toString);
|
||||
}
|
||||
|
|
|
@ -25,8 +25,6 @@
|
|||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
#include "pybind.h"
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
// An annotation on a class's attribute (corresponds to a c10::ClassAttribute).
|
||||
|
@ -162,7 +160,8 @@ public:
|
|||
// These will be put into an `ArgAnnotation` struct -- see there for
|
||||
// precise definitions of the promised semantics of each entry.
|
||||
void annotateArgs(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path, py::list argAnnotations);
|
||||
std::vector<std::string> path,
|
||||
std::vector<ArgAnnotation> argAnnotations);
|
||||
|
||||
// The annotations collected so far.
|
||||
const ClassAnnotationMap &getAnnotationMap();
|
||||
|
@ -192,8 +191,6 @@ private:
|
|||
functionToMethodMap;
|
||||
};
|
||||
|
||||
void initClassAnnotatorBindings(py::module &m);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_H
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
//===- class_annotator_pybind.cpp -----------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "class_annotator_pybind.h"
|
||||
#include "class_annotator.h"
|
||||
|
||||
#include <torch/csrc/Dtype.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
||||
return dtype->scalar_type;
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "unsupported scalar type '" << obj << "'";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
|
||||
std::vector<ArgAnnotation> argAnnotations(pyArgAnnotations.size());
|
||||
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
|
||||
if (pyArgAnnotations[i].is_none()) {
|
||||
continue;
|
||||
}
|
||||
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
|
||||
auto shape = tuple[0];
|
||||
auto dtype = tuple[1];
|
||||
auto hasValueSemantics = tuple[2];
|
||||
if (!shape.is_none()) {
|
||||
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
|
||||
}
|
||||
if (!dtype.is_none()) {
|
||||
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
|
||||
}
|
||||
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
|
||||
};
|
||||
|
||||
return argAnnotations;
|
||||
}
|
||||
|
||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||
.def(py::init<>())
|
||||
.def("exportPath", &ClassAnnotator::exportPath)
|
||||
.def("exportNone", &ClassAnnotator::exportNone)
|
||||
.def("annotateArgs",
|
||||
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path, py::list argAnnotations) {
|
||||
cls_annotator.annotateArgs(rootClassType, path,
|
||||
getArgAnnotations(argAnnotations));
|
||||
})
|
||||
.def("__repr__", &ClassAnnotator::toString);
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//===- module_builder.h -----------------------------------------*- C++ -*-===//
|
||||
//===- class_annotator_pybind.h ---------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -11,20 +11,14 @@
|
|||
// directly).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
||||
class mlir_diagnostic_emitted : public std::runtime_error {
|
||||
public:
|
||||
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
||||
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
||||
};
|
||||
|
||||
void initClassAnnotatorBindings(py::module &m);
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
|
@ -18,7 +18,6 @@
|
|||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#include <memory>
|
||||
|
||||
#include "node_importer.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
||||
|
||||
#include "pybind.h"
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
#include "class_annotator.h"
|
||||
#include "class_annotator_pybind.h"
|
||||
#include "get_registered_ops.h"
|
||||
#include "module_builder.h"
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#include <memory>
|
||||
|
||||
#include "class_annotator.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
|
|
|
@ -10,8 +10,6 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||
|
||||
#include "pybind.h"
|
||||
|
||||
#include "class_annotator.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
@ -20,6 +18,7 @@
|
|||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "torch-mlir-c/TorchOps.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
using Value = torch::jit::Value;
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
|
@ -22,6 +20,13 @@
|
|||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
||||
class mlir_diagnostic_emitted : public std::runtime_error {
|
||||
public:
|
||||
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
||||
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
||||
};
|
||||
|
||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
|
||||
/// type universe from the Python-based types modeled with `c10::Type`.
|
||||
|
|
Loading…
Reference in New Issue