diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt index 4a928401e..5545a32bd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt @@ -12,6 +12,7 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(TorchMLIRJITIRImporter MODULE class_annotator.cpp + class_annotator_pybind.cpp get_registered_ops.cpp function_importer.cpp module_builder.cpp diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp index b14005b9f..c02f014c7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp @@ -11,8 +11,6 @@ #include -#include "torch/csrc/Dtype.h" - using namespace torch_mlir; //===----------------------------------------------------------------------===// @@ -150,21 +148,10 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { return *it->second; } -static c10::ScalarType convertToC10ScalarType(py::object obj) { - if (THPDtype_Check(obj.ptr())) { - // Need reinterpret_cast, since no C++-level inheritance is involved. - THPDtype *dtype = reinterpret_cast(obj.ptr()); - return dtype->scalar_type; - } - std::stringstream ss; - ss << "unsupported scalar type '" << obj << "'"; - throw std::invalid_argument(ss.str()); -} - static void fillArgAnnotations(MethodAnnotation &methodAnnotation, - py::list pyArgAnnotations, + std::vector argAnnotations, torch::jit::Function *function) { - if (pyArgAnnotations.size() != function->num_inputs()) { + if (argAnnotations.size() != function->num_inputs()) { throw std::invalid_argument("Arg annotations should have one entry per " "function parameter (including self)."); } @@ -172,29 +159,13 @@ static void fillArgAnnotations(MethodAnnotation &methodAnnotation, methodAnnotation.argAnnotations.emplace(function->num_inputs(), ArgAnnotation{}); } - std::vector &argAnnotations = - methodAnnotation.argAnnotations.value(); - for (int i = 0, e = argAnnotations.size(); i != e; i++) { - if (pyArgAnnotations[i].is_none()) { - continue; - } - auto tuple = py::cast(pyArgAnnotations[i]); - auto shape = tuple[0]; - auto dtype = tuple[1]; - auto hasValueSemantics = tuple[2]; - if (!shape.is_none()) { - argAnnotations[i].shape = py::cast>(shape); - } - if (!dtype.is_none()) { - argAnnotations[i].dtype = convertToC10ScalarType(dtype); - } - argAnnotations[i].hasValueSemantics = py::cast(hasValueSemantics); - }; + + methodAnnotation.argAnnotations = argAnnotations; } void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, std::vector path, - py::list argAnnotations) { + std::vector argAnnotations) { if (path.size() == 0) { throw std::invalid_argument("Empty annotated path. Can only annotate " "shapes/dtypes of a method of a class."); @@ -331,12 +302,3 @@ std::string ClassAnnotator::toString() { ss << "}\n"; return ss.str(); } - -void torch_mlir::initClassAnnotatorBindings(py::module &m) { - py::class_(m, "ClassAnnotator") - .def(py::init<>()) - .def("exportPath", &ClassAnnotator::exportPath) - .def("exportNone", &ClassAnnotator::exportNone) - .def("annotateArgs", &ClassAnnotator::annotateArgs) - .def("__repr__", &ClassAnnotator::toString); -} diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h index 558ae49eb..03f39c1b1 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h @@ -25,8 +25,6 @@ #include -#include "pybind.h" - namespace torch_mlir { // An annotation on a class's attribute (corresponds to a c10::ClassAttribute). @@ -162,7 +160,8 @@ public: // These will be put into an `ArgAnnotation` struct -- see there for // precise definitions of the promised semantics of each entry. void annotateArgs(c10::ClassType &rootClassType, - std::vector path, py::list argAnnotations); + std::vector path, + std::vector argAnnotations); // The annotations collected so far. const ClassAnnotationMap &getAnnotationMap(); @@ -192,8 +191,6 @@ private: functionToMethodMap; }; -void initClassAnnotatorBindings(py::module &m); - } // namespace torch_mlir #endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_H diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp new file mode 100644 index 000000000..7d8525209 --- /dev/null +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp @@ -0,0 +1,63 @@ +//===- class_annotator_pybind.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "class_annotator_pybind.h" +#include "class_annotator.h" + +#include +#include + +using namespace torch_mlir; + +static c10::ScalarType convertToC10ScalarType(py::object obj) { + if (THPDtype_Check(obj.ptr())) { + // Need reinterpret_cast, since no C++-level inheritance is involved. + THPDtype *dtype = reinterpret_cast(obj.ptr()); + return dtype->scalar_type; + } + std::stringstream ss; + ss << "unsupported scalar type '" << obj << "'"; + throw std::invalid_argument(ss.str()); +} + +static std::vector getArgAnnotations(py::list pyArgAnnotations) { + std::vector argAnnotations(pyArgAnnotations.size()); + for (int i = 0, e = argAnnotations.size(); i != e; i++) { + if (pyArgAnnotations[i].is_none()) { + continue; + } + auto tuple = py::cast(pyArgAnnotations[i]); + auto shape = tuple[0]; + auto dtype = tuple[1]; + auto hasValueSemantics = tuple[2]; + if (!shape.is_none()) { + argAnnotations[i].shape = py::cast>(shape); + } + if (!dtype.is_none()) { + argAnnotations[i].dtype = convertToC10ScalarType(dtype); + } + argAnnotations[i].hasValueSemantics = py::cast(hasValueSemantics); + }; + + return argAnnotations; +} + +void torch_mlir::initClassAnnotatorBindings(py::module &m) { + py::class_(m, "ClassAnnotator") + .def(py::init<>()) + .def("exportPath", &ClassAnnotator::exportPath) + .def("exportNone", &ClassAnnotator::exportNone) + .def("annotateArgs", + [&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType, + std::vector path, py::list argAnnotations) { + cls_annotator.annotateArgs(rootClassType, path, + getArgAnnotations(argAnnotations)); + }) + .def("__repr__", &ClassAnnotator::toString); +} diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/pybind.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h similarity index 62% rename from python/torch_mlir/dialects/torch/importer/jit_ir/csrc/pybind.h rename to python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h index dc191516d..7dc42b86a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/pybind.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h @@ -1,4 +1,4 @@ -//===- module_builder.h -----------------------------------------*- C++ -*-===// +//===- class_annotator_pybind.h ---------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,20 +11,14 @@ // directly). //===----------------------------------------------------------------------===// -#ifndef TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H +#ifndef TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H #define TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H #include +namespace py = pybind11; namespace torch_mlir { - -/// Thrown on failure when details are in MLIR emitted diagnostics. -class mlir_diagnostic_emitted : public std::runtime_error { -public: - mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} - mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} -}; - +void initClassAnnotatorBindings(py::module &m); } // namespace torch_mlir -#endif // TORCHMLIRJITIRIMPORTER_CSRC_PYBIND_H +#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp index dcda400c7..0ec368903 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp @@ -18,7 +18,6 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -namespace py = pybind11; using namespace torch_mlir; MlirOperation torch_mlir::importJitFunctionAsFuncOp( diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h index c749c26bc..f9f5b10f7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h @@ -13,7 +13,6 @@ #include #include "node_importer.h" -#include "pybind.h" #include "mlir-c/IR.h" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h index 0fe9471fa..ec336878c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h @@ -15,7 +15,7 @@ #ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H #define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H -#include "pybind.h" +#include namespace torch_mlir { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp index 9638bc2a5..7506e4c9d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp @@ -11,7 +11,7 @@ #include -#include "class_annotator.h" +#include "class_annotator_pybind.h" #include "get_registered_ops.h" #include "module_builder.h" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h index d6ab53463..36e07cc2c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h @@ -13,7 +13,6 @@ #include #include "class_annotator.h" -#include "pybind.h" #include "mlir-c/IR.h" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h index 39944239a..95c85536d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h @@ -10,8 +10,6 @@ #ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H #define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H -#include "pybind.h" - #include "class_annotator.h" #include "mlir-c/IR.h" @@ -20,6 +18,7 @@ #include #include #include +#include namespace torch_mlir { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index a879d46ad..2d9805b4f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -22,7 +22,6 @@ #include "torch-mlir-c/TorchOps.h" #include "torch-mlir-c/TorchTypes.h" -namespace py = pybind11; using namespace torch_mlir; using Value = torch::jit::Value; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h index c4893fdb9..b05a37cfd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h @@ -12,8 +12,6 @@ #include -#include "pybind.h" - #include "mlir-c/IR.h" #include diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h index 328be1291..e3ff4f45d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h @@ -12,8 +12,6 @@ #include -#include "pybind.h" - #include "mlir-c/IR.h" #include @@ -22,6 +20,13 @@ namespace torch_mlir { +/// Thrown on failure when details are in MLIR emitted diagnostics. +class mlir_diagnostic_emitted : public std::runtime_error { +public: + mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} + mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} +}; + /// Gets a corresponding MlirType for the Torch ScalarType. /// `c10::`ScalarType` is used to represent tensor dtypes, and is a different /// type universe from the Python-based types modeled with `c10::Type`.