mirror of https://github.com/llvm/torch-mlir
Remove pybind deps from importer and annotator (#903)
* Remove pybind deps from importer and annotator * Rename files to class_annotator_pybind.cpp/.hpull/918/head snapshot-20220608.497
parent
e1b38e74dd
commit
bd53998da8
|
@ -12,6 +12,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
add_library(TorchMLIRJITIRImporter MODULE
|
add_library(TorchMLIRJITIRImporter MODULE
|
||||||
class_annotator.cpp
|
class_annotator.cpp
|
||||||
|
class_annotator_pybind.cpp
|
||||||
get_registered_ops.cpp
|
get_registered_ops.cpp
|
||||||
function_importer.cpp
|
function_importer.cpp
|
||||||
module_builder.cpp
|
module_builder.cpp
|
||||||
|
|
|
@ -11,8 +11,6 @@
|
||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "torch/csrc/Dtype.h"
|
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -150,21 +148,10 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
||||||
return *it->second;
|
return *it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
|
||||||
if (THPDtype_Check(obj.ptr())) {
|
|
||||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
|
||||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
|
||||||
return dtype->scalar_type;
|
|
||||||
}
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "unsupported scalar type '" << obj << "'";
|
|
||||||
throw std::invalid_argument(ss.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
||||||
py::list pyArgAnnotations,
|
std::vector<ArgAnnotation> argAnnotations,
|
||||||
torch::jit::Function *function) {
|
torch::jit::Function *function) {
|
||||||
if (pyArgAnnotations.size() != function->num_inputs()) {
|
if (argAnnotations.size() != function->num_inputs()) {
|
||||||
throw std::invalid_argument("Arg annotations should have one entry per "
|
throw std::invalid_argument("Arg annotations should have one entry per "
|
||||||
"function parameter (including self).");
|
"function parameter (including self).");
|
||||||
}
|
}
|
||||||
|
@ -172,29 +159,13 @@ static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
||||||
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
||||||
ArgAnnotation{});
|
ArgAnnotation{});
|
||||||
}
|
}
|
||||||
std::vector<ArgAnnotation> &argAnnotations =
|
|
||||||
methodAnnotation.argAnnotations.value();
|
methodAnnotation.argAnnotations = argAnnotations;
|
||||||
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
|
|
||||||
if (pyArgAnnotations[i].is_none()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
|
|
||||||
auto shape = tuple[0];
|
|
||||||
auto dtype = tuple[1];
|
|
||||||
auto hasValueSemantics = tuple[2];
|
|
||||||
if (!shape.is_none()) {
|
|
||||||
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
|
|
||||||
}
|
|
||||||
if (!dtype.is_none()) {
|
|
||||||
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
|
|
||||||
}
|
|
||||||
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
||||||
std::vector<std::string> path,
|
std::vector<std::string> path,
|
||||||
py::list argAnnotations) {
|
std::vector<ArgAnnotation> argAnnotations) {
|
||||||
if (path.size() == 0) {
|
if (path.size() == 0) {
|
||||||
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
||||||
"shapes/dtypes of a method of a class.");
|
"shapes/dtypes of a method of a class.");
|
||||||
|
@ -331,12 +302,3 @@ std::string ClassAnnotator::toString() {
|
||||||
ss << "}\n";
|
ss << "}\n";
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
|
||||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
|
||||||
.def(py::init<>())
|
|
||||||
.def("exportPath", &ClassAnnotator::exportPath)
|
|
||||||
.def("exportNone", &ClassAnnotator::exportNone)
|
|
||||||
.def("annotateArgs", &ClassAnnotator::annotateArgs)
|
|
||||||
.def("__repr__", &ClassAnnotator::toString);
|
|
||||||
}
|
|
||||||
|
|
|
@ -25,8 +25,6 @@
|
||||||
|
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
// An annotation on a class's attribute (corresponds to a c10::ClassAttribute).
|
// An annotation on a class's attribute (corresponds to a c10::ClassAttribute).
|
||||||
|
@ -162,7 +160,8 @@ public:
|
||||||
// These will be put into an `ArgAnnotation` struct -- see there for
|
// These will be put into an `ArgAnnotation` struct -- see there for
|
||||||
// precise definitions of the promised semantics of each entry.
|
// precise definitions of the promised semantics of each entry.
|
||||||
void annotateArgs(c10::ClassType &rootClassType,
|
void annotateArgs(c10::ClassType &rootClassType,
|
||||||
std::vector<std::string> path, py::list argAnnotations);
|
std::vector<std::string> path,
|
||||||
|
std::vector<ArgAnnotation> argAnnotations);
|
||||||
|
|
||||||
// The annotations collected so far.
|
// The annotations collected so far.
|
||||||
const ClassAnnotationMap &getAnnotationMap();
|
const ClassAnnotationMap &getAnnotationMap();
|
||||||
|
@ -192,8 +191,6 @@ private:
|
||||||
functionToMethodMap;
|
functionToMethodMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
void initClassAnnotatorBindings(py::module &m);
|
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_H
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
//===- class_annotator_pybind.cpp -----------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "class_annotator_pybind.h"
|
||||||
|
#include "class_annotator.h"
|
||||||
|
|
||||||
|
#include <torch/csrc/Dtype.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||||
|
if (THPDtype_Check(obj.ptr())) {
|
||||||
|
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||||
|
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
||||||
|
return dtype->scalar_type;
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "unsupported scalar type '" << obj << "'";
|
||||||
|
throw std::invalid_argument(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
|
||||||
|
std::vector<ArgAnnotation> argAnnotations(pyArgAnnotations.size());
|
||||||
|
for (int i = 0, e = argAnnotations.size(); i != e; i++) {
|
||||||
|
if (pyArgAnnotations[i].is_none()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto tuple = py::cast<py::tuple>(pyArgAnnotations[i]);
|
||||||
|
auto shape = tuple[0];
|
||||||
|
auto dtype = tuple[1];
|
||||||
|
auto hasValueSemantics = tuple[2];
|
||||||
|
if (!shape.is_none()) {
|
||||||
|
argAnnotations[i].shape = py::cast<std::vector<int64_t>>(shape);
|
||||||
|
}
|
||||||
|
if (!dtype.is_none()) {
|
||||||
|
argAnnotations[i].dtype = convertToC10ScalarType(dtype);
|
||||||
|
}
|
||||||
|
argAnnotations[i].hasValueSemantics = py::cast<bool>(hasValueSemantics);
|
||||||
|
};
|
||||||
|
|
||||||
|
return argAnnotations;
|
||||||
|
}
|
||||||
|
|
||||||
|
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
||||||
|
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def("exportPath", &ClassAnnotator::exportPath)
|
||||||
|
.def("exportNone", &ClassAnnotator::exportNone)
|
||||||
|
.def("annotateArgs",
|
||||||
|
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
|
||||||
|
std::vector<std::string> path, py::list argAnnotations) {
|
||||||
|
cls_annotator.annotateArgs(rootClassType, path,
|
||||||
|
getArgAnnotations(argAnnotations));
|
||||||
|
})
|
||||||
|
.def("__repr__", &ClassAnnotator::toString);
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
//===- module_builder.h -----------------------------------------*- C++ -*-===//
|
//===- class_annotator_pybind.h ---------------------------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -11,20 +11,14 @@
|
||||||
// directly).
|
// directly).
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
||||||
#define TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
#define TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
||||||
|
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
void initClassAnnotatorBindings(py::module &m);
|
||||||
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
|
||||||
class mlir_diagnostic_emitted : public std::runtime_error {
|
|
||||||
public:
|
|
||||||
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
|
||||||
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
|
@ -18,7 +18,6 @@
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "node_importer.h"
|
#include "node_importer.h"
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
||||||
#define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
#define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H
|
||||||
|
|
||||||
#include "pybind.h"
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
#include <ATen/core/dispatch/Dispatcher.h>
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "class_annotator_pybind.h"
|
||||||
#include "get_registered_ops.h"
|
#include "get_registered_ops.h"
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "class_annotator.h"
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
|
|
@ -10,8 +10,6 @@
|
||||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||||
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||||
|
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "class_annotator.h"
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
@ -20,6 +18,7 @@
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include "torch-mlir-c/TorchOps.h"
|
#include "torch-mlir-c/TorchOps.h"
|
||||||
#include "torch-mlir-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
using Value = torch::jit::Value;
|
using Value = torch::jit::Value;
|
||||||
|
|
|
@ -12,8 +12,6 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
|
|
@ -12,8 +12,6 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "pybind.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
@ -22,6 +20,13 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
||||||
|
class mlir_diagnostic_emitted : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
||||||
|
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
||||||
|
};
|
||||||
|
|
||||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||||
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
|
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
|
||||||
/// type universe from the Python-based types modeled with `c10::Type`.
|
/// type universe from the Python-based types modeled with `c10::Type`.
|
||||||
|
|
Loading…
Reference in New Issue