Remove pybind deps from importer and annotator (#903)

* Remove pybind deps from importer and annotator
* Rename files to class_annotator_pybind.cpp/.h
pull/918/head snapshot-20220608.497
Tanyo Kwok 2022-06-08 10:12:05 +08:00 committed by GitHub
parent e1b38e74dd
commit bd53998da8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 86 additions and 71 deletions

View File

@ -12,6 +12,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(TorchMLIRJITIRImporter MODULE
class_annotator.cpp
class_annotator_pybind.cpp
get_registered_ops.cpp
function_importer.cpp
module_builder.cpp

View File

@ -11,8 +11,6 @@
#include <stdexcept>
#include "torch/csrc/Dtype.h"
using namespace torch_mlir;
//===----------------------------------------------------------------------===//
@ -150,21 +148,10 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
return *it->second;
}
static c10::ScalarType convertToC10ScalarType(py::object obj) {
if (THPDtype_Check(obj.ptr())) {
// Need reinterpret_cast, since no C++-level inheritance is involved.
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
return dtype->scalar_type;
}
std::stringstream ss;
ss << "unsupported scalar type '" << obj << "'";
throw std::invalid_argument(ss.str());
}
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
py::list pyArgAnnotations,
std::vector<ArgAnnotation> argAnnotations,
torch::jit::Function *function) {
if (pyArgAnnotations.size() != function->num_inputs()) {
if (argAnnotations.size() != function->num_inputs()) {
throw std::invalid_argument("Arg annotations should have one entry per "
"function parameter (including self).");
}
@ -172,29 +159,13 @@ static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
ArgAnnotation{});
}
std::vector<ArgAnnotation> &argAnnotations =
methodAnnotation.argAnnotations.value();
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
if (pyArgAnnotations[i].is_none()) {
continue;
}
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
auto shape = tuple[0];
auto dtype = tuple[1];
auto hasValueSemantics = tuple[2];
if (!shape.is_none()) {
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
}
if (!dtype.is_none()) {
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
}
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
};
methodAnnotation.argAnnotations = argAnnotations;
}
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
std::vector<std::string> path,
py::list argAnnotations) {
std::vector<ArgAnnotation> argAnnotations) {
if (path.size() == 0) {
throw std::invalid_argument("Empty annotated path. Can only annotate "
"shapes/dtypes of a method of a class.");
@ -331,12 +302,3 @@ std::string ClassAnnotator::toString() {
ss << "}\n";
return ss.str();
}
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
py::class_<ClassAnnotator>(m, "ClassAnnotator")
.def(py::init<>())
.def("exportPath", &ClassAnnotator::exportPath)
.def("exportNone", &ClassAnnotator::exportNone)
.def("annotateArgs", &ClassAnnotator::annotateArgs)
.def("__repr__", &ClassAnnotator::toString);
}

View File

@ -25,8 +25,6 @@
#include <torch/csrc/jit/ir/ir.h>
#include "pybind.h"
namespace torch_mlir {
// An annotation on a class's attribute (corresponds to a c10::ClassAttribute).
@ -162,7 +160,8 @@ public:
// These will be put into an `ArgAnnotation` struct -- see there for
// precise definitions of the promised semantics of each entry.
void annotateArgs(c10::ClassType &rootClassType,
std::vector<std::string> path, py::list argAnnotations);
std::vector<std::string> path,
std::vector<ArgAnnotation> argAnnotations);
// The annotations collected so far.
const ClassAnnotationMap &getAnnotationMap();
@ -192,8 +191,6 @@ private:
functionToMethodMap;
};
void initClassAnnotatorBindings(py::module &m);
} // namespace torch_mlir
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_H

View File

@ -0,0 +1,63 @@
//===- class_annotator_pybind.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "class_annotator_pybind.h"
#include "class_annotator.h"
#include <torch/csrc/Dtype.h>
#include <torch/csrc/utils/pybind.h>
using namespace torch_mlir;
static c10::ScalarType convertToC10ScalarType(py::object obj) {
if (THPDtype_Check(obj.ptr())) {
// Need reinterpret_cast, since no C++-level inheritance is involved.
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
return dtype->scalar_type;
}
std::stringstream ss;
ss << "unsupported scalar type '" << obj << "'";
throw std::invalid_argument(ss.str());
}
static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
std::vector<ArgAnnotation> argAnnotations(pyArgAnnotations.size());
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
if (pyArgAnnotations[i].is_none()) {
continue;
}
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
auto shape = tuple[0];
auto dtype = tuple[1];
auto hasValueSemantics = tuple[2];
if (!shape.is_none()) {
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
}
if (!dtype.is_none()) {
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
}
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
};
return argAnnotations;
}
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
py::class_<ClassAnnotator>(m, "ClassAnnotator")
.def(py::init<>())
.def("exportPath", &ClassAnnotator::exportPath)
.def("exportNone", &ClassAnnotator::exportNone)
.def("annotateArgs",
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
std::vector<std::string> path, py::list argAnnotations) {
cls_annotator.annotateArgs(rootClassType, path,
getArgAnnotations(argAnnotations));
})
.def("__repr__", &ClassAnnotator::toString);
}

View File

@ -1,4 +1,4 @@
//===- module_builder.h -----------------------------------------*- C++ -*-===//
//===- class_annotator_pybind.h ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -11,20 +11,14 @@
// directly).
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
#define TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch_mlir {
/// Thrown on failure when details are in MLIR emitted diagnostics.
class mlir_diagnostic_emitted : public std::runtime_error {
public:
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
};
void initClassAnnotatorBindings(py::module &m);
} // namespace torch_mlir
#endif // TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H

View File

@ -18,7 +18,6 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation torch_mlir::importJitFunctionAsFuncOp(

View File

@ -13,7 +13,6 @@
#include <memory>
#include "node_importer.h"
#include "pybind.h"
#include "mlir-c/IR.h"

View File

@ -15,7 +15,7 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
#define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
#include "pybind.h"
#include <torch/csrc/utils/pybind.h>
namespace torch_mlir {

View File

@ -11,7 +11,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include "class_annotator.h"
#include "class_annotator_pybind.h"
#include "get_registered_ops.h"
#include "module_builder.h"

View File

@ -13,7 +13,6 @@
#include <memory>
#include "class_annotator.h"
#include "pybind.h"
#include "mlir-c/IR.h"

View File

@ -10,8 +10,6 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
#include "pybind.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"
@ -20,6 +18,7 @@
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/pybind.h>
namespace torch_mlir {

View File

@ -22,7 +22,6 @@
#include "torch-mlir-c/TorchOps.h"
#include "torch-mlir-c/TorchTypes.h"
namespace py = pybind11;
using namespace torch_mlir;
using Value = torch::jit::Value;

View File

@ -12,8 +12,6 @@
#include <memory>
#include "pybind.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>

View File

@ -12,8 +12,6 @@
#include <memory>
#include "pybind.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
@ -22,6 +20,13 @@
namespace torch_mlir {
/// Thrown on failure when details are in MLIR emitted diagnostics.
class mlir_diagnostic_emitted : public std::runtime_error {
public:
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
};
/// Gets a corresponding MlirType for the Torch ScalarType.
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
/// type universe from the Python-based types modeled with `c10::Type`.