From 606dc45896f13146e0dab5f4ffdf55de1e17dc85 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 18 Nov 2023 17:56:00 -0800 Subject: [PATCH] Get LTC building. --- .gitignore | 2 +- build_tools/autogen_ltc_backend.py | 11 +- docs/ltc_backend.md | 6 +- .../ltc/csrc/base_lazy_backend/CMakeLists.txt | 14 +- .../mlir_lowering_context.cpp | 2 +- projects/pt1/python/CMakeLists.txt | 1 + .../jit_ir_importer}/CMakeLists.txt | 4 + .../jit_ir_importer}/class_annotator.cpp | 106 ++++----- .../jit_ir_importer}/class_annotator.h | 34 +-- .../class_annotator_pybind.cpp | 17 +- .../jit_ir_importer}/class_annotator_pybind.h | 2 +- .../jit_ir_importer}/function_importer.cpp | 13 +- .../jit_ir_importer}/function_importer.h | 4 +- .../jit_ir_importer}/get_registered_ops.cpp | 16 +- .../jit_ir_importer}/get_registered_ops.h | 2 +- .../jit_ir_importer}/import_options.h | 0 .../import_options_pybind.cpp | 12 +- .../jit_ir_importer}/import_options_pybind.h | 2 +- .../jit_ir_importer}/init_python_bindings.cpp | 0 .../jit_ir_importer}/ivalue_importer.cpp | 126 ++++++----- .../jit_ir_importer}/ivalue_importer.h | 6 +- .../jit_ir_importer}/mlir_utils.h | 60 ++--- .../jit_ir_importer}/module_builder.cpp | 65 +++--- .../jit_ir_importer}/module_builder.h | 13 +- .../jit_ir_importer}/node_importer.cpp | 211 +++++++++--------- .../jit_ir_importer}/node_importer.h | 4 +- .../jit_ir_importer}/torch_to_mlir_utils.cpp | 187 ++++++++-------- .../jit_ir_importer}/torch_to_mlir_utils.h | 46 ++-- .../reference_lazy_backend/backend_impl.cpp | 12 +- .../reference_lazy_backend_pybind.cpp | 8 +- .../torch_mlir/jit_ir_importer/CMakeLists.txt | 2 - 31 files changed, 508 insertions(+), 480 deletions(-) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/CMakeLists.txt (92%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/class_annotator.cpp (71%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/class_annotator.h (88%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/class_annotator_pybind.cpp (80%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/class_annotator_pybind.h (95%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/function_importer.cpp (88%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/function_importer.h (94%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/get_registered_ops.cpp (89%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/get_registered_ops.h (94%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/import_options.h (100%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/import_options_pybind.cpp (65%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/import_options_pybind.h (92%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/init_python_bindings.cpp (100%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/ivalue_importer.cpp (85%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/ivalue_importer.h (83%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/mlir_utils.h (54%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/module_builder.cpp (76%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/module_builder.h (85%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/node_importer.cpp (67%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/node_importer.h (94%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/torch_to_mlir_utils.cpp (74%) rename projects/pt1/python/torch_mlir/{jit_ir_importer/csrc => csrc/jit_ir_importer}/torch_to_mlir_utils.h (61%) diff --git a/.gitignore b/.gitignore index 6b76bc3ea..5c4074289 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,7 @@ __pycache__ bazel-* # Autogenerated files -/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated +/projects/ltc/csrc/base_lazy_backend/generated #Docker builds build_oot/ diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 02ac0eff0..40a64c1c1 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -29,7 +29,6 @@ if not TORCH_INCLUDE_DIR.is_dir(): TORCH_INCLUDE_DIR = TORCH_DIR TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent -TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1" def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -114,12 +113,12 @@ class GenTorchMlirLTC: self.binary_dir = Path(binary_dir) assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}" self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml") - self.backend_path = TORCH_MLIR_PT1_DIR.joinpath( - "python", "torch_mlir", "csrc", "base_lazy_backend" + self.backend_path = TORCH_MLIR_DIR.joinpath( + "projects", "ltc", "csrc", "base_lazy_backend" ) assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}" self.generated_path = self.binary_dir.joinpath( - "projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated" + "projects", "ltc", "csrc", "base_lazy_backend", "generated" ) self.generated_path.mkdir(parents=True, exist_ok=True) @@ -415,7 +414,7 @@ class GenTorchMlirLTC: // for ops that dont have a corresponding structured kernel or shape definition #include "shape_inference.h" - #include "torch_mlir/csrc/base_lazy_backend/utils/exception.h" + #include "base_lazy_backend/utils/exception.h" namespace torch {{ namespace lazy {{ {} @@ -467,7 +466,7 @@ class GenTorchMlirLTC: node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), tensor_class=self.tensor_class, - tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", + tensor_class_hdr="base_lazy_backend/tensor.h", create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), lazy_ir_generator=GenMlirLazyIr, diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md index ae3cc887c..b01775428 100644 --- a/docs/ltc_backend.md +++ b/docs/ltc_backend.md @@ -12,7 +12,7 @@ [Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR. After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation. -LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. +LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. Implementations based on this abstract class will be able to specify their own compile and execution workflows. Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend). @@ -27,7 +27,7 @@ View examples [here](ltc_examples.md). - The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td), excluding those explicitly blacklisted in the YAML file -### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated)) +### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated)) Generated files are created in this directory, which is ignored by version control. - `LazyIr.h` @@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr - `shape_inference.{cpp,h}` - Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions -### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend)) +### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend)) - `backend_impl.{cpp,h}` - Base LTC backend to setup Torch-MLIR lowering context diff --git a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt index 1a651faed..eee3044f0 100644 --- a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt +++ b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt @@ -56,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED utils/tensor_utils.cpp ) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) +# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/..."). +# Add both the source and generated include directories. +target_include_directories(torch_mlir_ltc_backend PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_BINARY_DIR}/.. +) add_dependencies(torch_mlir_ltc_backend TorchMLIRJITIRImporter @@ -88,13 +94,13 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/) add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/) add_custom_command( @@ -105,7 +111,7 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/) add_custom_command( @@ -116,5 +122,5 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/) diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp index fd93d4d2b..7e6f40c5c 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -21,8 +21,8 @@ #include "mlir-c/IR.h" #include "mlir-c/Pass.h" -#include "../../jit_ir_importer/csrc/function_importer.h" #include "backend_impl.h" +#include "jit_ir_importer/function_importer.h" #include "mlir_lowering_context.h" #include "mlir_node.h" #include "utils/debug.h" diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index 70d21301b..23dd5be40 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -100,6 +100,7 @@ endif() if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) add_subdirectory(torch_mlir/jit_ir_importer) + add_subdirectory(torch_mlir/csrc/jit_ir_importer) add_subdirectory(torch_mlir_e2e_test) endif() diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/CMakeLists.txt b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt similarity index 92% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/CMakeLists.txt rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt index 10870a7b0..bcf1ec89d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt @@ -12,6 +12,10 @@ target_link_libraries(TorchMLIRJITIRImporter TorchMLIRAggregateCAPI ${TORCH_LIBRARIES} ) +# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...") +target_include_directories(TorchMLIRJITIRImporter PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. +) set_target_properties(TorchMLIRJITIRImporter PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" OUTPUT_NAME lib_jit_ir_importer diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp similarity index 71% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp index b144e946b..9f936486f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp @@ -18,8 +18,8 @@ using namespace torch_mlir; //===----------------------------------------------------------------------===// // Prefix every line of `s` with `linePrefix`. -static std::string indentString(const std::string &linePrefix, - const std::string &s) { +static std::string +indentString(const std::string& linePrefix, const std::string& s) { std::stringstream is(s); std::stringstream os; std::string line; @@ -39,26 +39,28 @@ ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType) methodAnnotations.resize(classType->methods().size()); } -std::vector &ClassAnnotation::getAttributeAnnotations() { +std::vector& ClassAnnotation::getAttributeAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // // We can't easily guard against attributes being removed and // then other attributes being added, or types changed, etc. without // effectively mirroring the entire ClassType. - assert(attributeAnnotations.size() == classType->getAttributes().size() && - "annotations out of sync. class has been mutated"); + assert( + attributeAnnotations.size() == classType->getAttributes().size() && + "annotations out of sync. class has been mutated"); return attributeAnnotations; } -std::vector &ClassAnnotation::getMethodAnnotations() { +std::vector& ClassAnnotation::getMethodAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // // We can't easily guard against methods being removed, added, or changed. - assert(methodAnnotations.size() == classType->methods().size() && - "annotations out of sync. class has been mutated"); + assert( + methodAnnotations.size() == classType->methods().size() && + "annotations out of sync. class has been mutated"); return methodAnnotations; } @@ -67,17 +69,17 @@ std::vector &ClassAnnotation::getMethodAnnotations() { // ClassAnnotator //===----------------------------------------------------------------------===// -static void exportNoneRecurse(ClassAnnotator &classAnnotator, - c10::ClassType *classType) { - ClassAnnotation &classAnnotation = +static void +exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) { + ClassAnnotation& classAnnotation = classAnnotator.getOrCreateClassAnnotation(classType); - for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) { + for (auto& attributeAnnotation : classAnnotation.getAttributeAnnotations()) { attributeAnnotation.isExported = false; } - for (auto &methodAnnotation : classAnnotation.getMethodAnnotations()) { + for (auto& methodAnnotation : classAnnotation.getMethodAnnotations()) { methodAnnotation.isExported = false; } - for (auto &classAttribute : classType->getAttributes()) { + for (auto& classAttribute : classType->getAttributes()) { if (auto childClassType = classAttribute.getType()->cast()) { exportNoneRecurse(classAnnotator, childClassType.get()); @@ -85,20 +87,20 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator, } } -void ClassAnnotator::exportNone(c10::ClassType &rootClassType) { +void ClassAnnotator::exportNone(c10::ClassType& rootClassType) { exportNoneRecurse(*this, &rootClassType); } -void ClassAnnotator::exportPath(c10::ClassType &rootClassType, - std::vector exportedPath) { +void ClassAnnotator::exportPath( + c10::ClassType& rootClassType, std::vector exportedPath) { if (exportedPath.size() == 0) { throw std::invalid_argument( "Empty exported path. Can only export a property of a class."); } - c10::ClassType *classType = - getClassAtPath(&rootClassType, c10::ArrayRef(exportedPath) - .slice(0, exportedPath.size() - 1) - .vec()); + c10::ClassType* classType = getClassAtPath( + &rootClassType, c10::ArrayRef(exportedPath) + .slice(0, exportedPath.size() - 1) + .vec()); if (!classType->findAttribute(exportedPath.back()) && !classType->findMethod(exportedPath.back())) { @@ -108,10 +110,10 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType, << exportedPath.back() << "'"; throw std::invalid_argument(ss.str()); } - ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); - std::vector &attributeAnnotations = + ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); + std::vector& attributeAnnotations = classAnnotation.getAttributeAnnotations(); - const std::vector &classAttributes = + const std::vector& classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { if (classAttributes[i].getName() == exportedPath.back()) { @@ -119,9 +121,9 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType, } } - std::vector &methodAnnotations = + std::vector& methodAnnotations = classAnnotation.getMethodAnnotations(); - const std::vector &methods = classType->methods(); + const std::vector& methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { if (methods[i]->name() == exportedPath.back()) { methodAnnotations[i].isExported = true; @@ -129,12 +131,12 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType, } } -const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() { +const ClassAnnotationMap& ClassAnnotator::getAnnotationMap() { return classAnnotations; } -ClassAnnotation & -ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { +ClassAnnotation& +ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) { auto className = classType->name()->qualifiedName(); auto it = classAnnotations.find(className); if (it == classAnnotations.end()) { @@ -149,39 +151,39 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { return *it->second; } -static void fillArgAnnotations(MethodAnnotation &methodAnnotation, - std::vector argAnnotations, - torch::jit::Function *function) { +static void fillArgAnnotations( + MethodAnnotation& methodAnnotation, + std::vector argAnnotations, torch::jit::Function* function) { if (argAnnotations.size() != function->num_inputs()) { throw std::invalid_argument("Arg annotations should have one entry per " "function parameter (including self)."); } if (!methodAnnotation.argAnnotations.has_value()) { - methodAnnotation.argAnnotations.emplace(function->num_inputs(), - ArgAnnotation{}); + methodAnnotation.argAnnotations.emplace( + function->num_inputs(), ArgAnnotation{}); } methodAnnotation.argAnnotations = argAnnotations; } -void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, - std::vector path, - std::vector argAnnotations) { +void ClassAnnotator::annotateArgs( + c10::ClassType& rootClassType, std::vector path, + std::vector argAnnotations) { if (path.size() == 0) { throw std::invalid_argument("Empty annotated path. Can only annotate " "shapes/dtypes of a method of a class."); } - c10::ClassType *classType = getClassAtPath( + c10::ClassType* classType = getClassAtPath( &rootClassType, c10::ArrayRef(path).slice(0, path.size() - 1).vec()); // Throw error if no method on the class of the specified name. - torch::jit::Function *function = &classType->getMethod(path.back()); + torch::jit::Function* function = &classType->getMethod(path.back()); - ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); - std::vector &methodAnnotations = + ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); + std::vector& methodAnnotations = classAnnotation.getMethodAnnotations(); - const std::vector &methods = classType->methods(); + const std::vector& methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { if (methods[i]->name() == path.back()) { fillArgAnnotations(methodAnnotations[i], argAnnotations, function); @@ -191,9 +193,9 @@ void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, return; } -c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType, - std::vector path) { - c10::ClassType *classType = rootClassType; +c10::ClassType* ClassAnnotator::getClassAtPath( + c10::ClassType* rootClassType, std::vector path) { + c10::ClassType* classType = rootClassType; // Reverse so that pop_back gives us the initial atoms first. std::reverse(path.begin(), path.end()); while (!path.empty()) { @@ -215,8 +217,8 @@ c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType, //===----------------------------------------------------------------------===// // Helper methods //===----------------------------------------------------------------------===// -MethodAnnotation * -ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) { +MethodAnnotation* +ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) { auto it = functionToMethodMap.find(function); if (it == functionToMethodMap.end()) { return nullptr; @@ -228,7 +230,7 @@ ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) { // toString methods //===----------------------------------------------------------------------===// -std::string AttributeAnnotation::toString(const std::string &name) { +std::string AttributeAnnotation::toString(const std::string& name) { std::stringstream ss; ss << "AttributeAnnotation('" << name << "') {\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n"; @@ -259,7 +261,7 @@ std::string ArgAnnotation::toString(int argIndex) { return ss.str(); } -std::string MethodAnnotation::toString(const std::string &name) { +std::string MethodAnnotation::toString(const std::string& name) { std::stringstream ss; ss << "MethodAnnotation('" << name << "') {\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n"; @@ -280,13 +282,13 @@ std::string ClassAnnotation::toString() { std::stringstream ss; ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n"; - const std::vector &classAttributes = + const std::vector& classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { ss << indentString( " ", attributeAnnotations[i].toString(classAttributes[i].getName())); } - const std::vector &methods = classType->methods(); + const std::vector& methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name())); } @@ -297,7 +299,7 @@ std::string ClassAnnotation::toString() { std::string ClassAnnotator::toString() { std::stringstream ss; ss << "ClassAnnotator {\n"; - for (auto &p : classAnnotations) { + for (auto& p : classAnnotations) { ss << indentString(" ", p.second->toString()); } ss << "}\n"; diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h similarity index 88% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h index 0a0815eab..11aa4e434 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h @@ -34,7 +34,7 @@ struct AttributeAnnotation { // can be externally accessed. bool isExported = true; - std::string toString(const std::string &name); + std::string toString(const std::string& name); }; // An annotation of an argument of a method. @@ -80,7 +80,7 @@ struct MethodAnnotation { // large printout of the default ArgAnnotation for every method. c10::optional> argAnnotations; - std::string toString(const std::string &name); + std::string toString(const std::string& name); }; // Annotations on a c10::ClassType. @@ -107,10 +107,10 @@ public: // Get the attribute annotations. // The length and order is the same as `classType->getAttributes()`. - std::vector &getAttributeAnnotations(); + std::vector& getAttributeAnnotations(); // Get the method annotations. // The length and order is the same as `classType->methods()`. - std::vector &getMethodAnnotations(); + std::vector& getMethodAnnotations(); std::string toString(); @@ -141,14 +141,14 @@ public: // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should // have a submodule `a` and that submodule should have a method or attribute // `b`. - void exportPath(c10::ClassType &rootClassType, - std::vector exportedPath); + void exportPath( + c10::ClassType& rootClassType, std::vector exportedPath); // Mark everything as not-exported. // // This is kind of useless by itself, but together with `exportPath` allows // exporting a subset of known names out of a larger collection of unknown // names. - void exportNone(c10::ClassType &rootClassType); + void exportNone(c10::ClassType& rootClassType); // Annotate shapes and dtypes of the arguments of a method at path `path` from // `rootClassType`. @@ -159,23 +159,23 @@ public: // a "has value semantics" boolean. // These will be put into an `ArgAnnotation` struct -- see there for // precise definitions of the promised semantics of each entry. - void annotateArgs(c10::ClassType &rootClassType, - std::vector path, - std::vector argAnnotations); + void annotateArgs( + c10::ClassType& rootClassType, std::vector path, + std::vector argAnnotations); // The annotations collected so far. - const ClassAnnotationMap &getAnnotationMap(); + const ClassAnnotationMap& getAnnotationMap(); // Get the ClassAnnotation corresponding to `classType`. - ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType); + ClassAnnotation& getOrCreateClassAnnotation(c10::ClassType* classType); // Helper to find the MethodAnnotation corresponding to a // torch::jit::Function, or null if not found. // // Users could in principle scan all annotations to find this, but it's more // efficient to maintain the reverse mapping directly. - MethodAnnotation * - getMethodAnnotationForFunction(torch::jit::Function *function); + MethodAnnotation* + getMethodAnnotationForFunction(torch::jit::Function* function); std::string toString(); @@ -183,11 +183,11 @@ private: // Traverse `path` starting from `rootClassType` to find the ClassType // of a presumed nested submodule. Throw an error if there is no such // submodule. - c10::ClassType *getClassAtPath(c10::ClassType *rootClassType, - std::vector path); + c10::ClassType* + getClassAtPath(c10::ClassType* rootClassType, std::vector path); ClassAnnotationMap classAnnotations; // Reverse mapping used to service getMethodAnnotationForFunction. - std::unordered_map + std::unordered_map functionToMethodMap; }; diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp similarity index 80% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp index 7d8525209..2e5296820 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp @@ -18,7 +18,7 @@ using namespace torch_mlir; static c10::ScalarType convertToC10ScalarType(py::object obj) { if (THPDtype_Check(obj.ptr())) { // Need reinterpret_cast, since no C++-level inheritance is involved. - THPDtype *dtype = reinterpret_cast(obj.ptr()); + THPDtype* dtype = reinterpret_cast(obj.ptr()); return dtype->scalar_type; } std::stringstream ss; @@ -48,16 +48,17 @@ static std::vector getArgAnnotations(py::list pyArgAnnotations) { return argAnnotations; } -void torch_mlir::initClassAnnotatorBindings(py::module &m) { +void torch_mlir::initClassAnnotatorBindings(py::module& m) { py::class_(m, "ClassAnnotator") .def(py::init<>()) .def("exportPath", &ClassAnnotator::exportPath) .def("exportNone", &ClassAnnotator::exportNone) - .def("annotateArgs", - [&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType, - std::vector path, py::list argAnnotations) { - cls_annotator.annotateArgs(rootClassType, path, - getArgAnnotations(argAnnotations)); - }) + .def( + "annotateArgs", + [&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType, + std::vector path, py::list argAnnotations) { + cls_annotator.annotateArgs( + rootClassType, path, getArgAnnotations(argAnnotations)); + }) .def("__repr__", &ClassAnnotator::toString); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h similarity index 95% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h index a0d1a7581..4eb170b8b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/class_annotator_pybind.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h @@ -18,7 +18,7 @@ namespace py = pybind11; namespace torch_mlir { -void initClassAnnotatorBindings(py::module &m); +void initClassAnnotatorBindings(py::module& m); } // namespace torch_mlir #endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp similarity index 88% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp index 4a538fbcb..31d560a73 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp @@ -21,9 +21,9 @@ using namespace torch_mlir; MlirOperation torch_mlir::importJitFunctionAsFuncOp( - MlirContext context, torch::jit::Function *function, + MlirContext context, torch::jit::Function* function, std::function getArgAttribute, - const ImportOptions &importOptions) { + const ImportOptions& importOptions) { // Useful for debugging: // graph->dump(); MlirLocation loc = mlirLocationUnknownGet(context); @@ -63,10 +63,11 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( } auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { - createMlirOperationAtEnd(appendToBlock, "func.return", loc, - adjustStaticInformationForValues( - appendToBlock, loc, yieldedValues, resultTypes, - /*userAllowsRefinement=*/false)); + createMlirOperationAtEnd( + appendToBlock, "func.return", loc, + adjustStaticInformationForValues( + appendToBlock, loc, yieldedValues, resultTypes, + /*userAllowsRefinement=*/false)); }; MlirBlock block = importBlock( context, torch::jit::toGraphFunction(*function).graph()->block(), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h similarity index 94% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h index 626068f76..a211f6c46 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/function_importer.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h @@ -40,10 +40,10 @@ namespace torch_mlir { /// null MlirAttribute is returned, no attribute will be attached to that /// argument. MlirOperation importJitFunctionAsFuncOp( - MlirContext context, torch::jit::Function *function, + MlirContext context, torch::jit::Function* function, std::function getArgAttribute = [](int) -> MlirAttribute { return {nullptr}; }, - const ImportOptions &importOptions = {}); + const ImportOptions& importOptions = {}); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp similarity index 89% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp index 2b90b3b65..a168ca1c0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp @@ -50,9 +50,9 @@ static py::list getRegisteredOps() { // since the JIT has its own dispatch mechanism that it uses to implement // "prim" ops and a handful of "aten" ops that are effectively prim ops, such // as `aten::__is__`. - for (const std::shared_ptr &op : + for (const std::shared_ptr& op : torch::jit::getAllOperators()) { - const c10::FunctionSchema &schema = op->schema(); + const c10::FunctionSchema& schema = op->schema(); py::dict record; { @@ -69,7 +69,7 @@ static py::list getRegisteredOps() { py::list arguments; py::list returns; - auto addArgument = [](py::list &container, const c10::Argument &arg) { + auto addArgument = [](py::list& container, const c10::Argument& arg) { py::dict argRecord; argRecord["name"] = arg.name(); argRecord["type"] = arg.type()->str(); @@ -87,10 +87,10 @@ static py::list getRegisteredOps() { py::dict aliasInfo; py::list before; py::list after; - for (auto &symbol : arg.alias_info()->beforeSets()) { + for (auto& symbol : arg.alias_info()->beforeSets()) { before.append(std::string(symbol.toQualString())); } - for (auto &symbol : arg.alias_info()->afterSets()) { + for (auto& symbol : arg.alias_info()->afterSets()) { after.append(std::string(symbol.toQualString())); } aliasInfo["is_write"] = arg.alias_info()->isWrite(); @@ -101,10 +101,10 @@ static py::list getRegisteredOps() { container.append(std::move(argRecord)); }; - for (auto &argument : schema.arguments()) { + for (auto& argument : schema.arguments()) { addArgument(arguments, argument); } - for (auto &returnArg : schema.returns()) { + for (auto& returnArg : schema.returns()) { addArgument(returns, returnArg); } record["arguments"] = std::move(arguments); @@ -115,6 +115,6 @@ static py::list getRegisteredOps() { return results; } -void torch_mlir::initGetRegisteredOpsBindings(py::module &m) { +void torch_mlir::initGetRegisteredOpsBindings(py::module& m) { m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h similarity index 94% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h index ec336878c..b2851e6a4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/get_registered_ops.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h @@ -19,7 +19,7 @@ namespace torch_mlir { -void initGetRegisteredOpsBindings(py::module &m); +void initGetRegisteredOpsBindings(py::module& m); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options.h similarity index 100% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options.h diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp similarity index 65% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp index b072b0ed9..3e0183a95 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp @@ -14,11 +14,13 @@ namespace py = pybind11; using namespace torch_mlir; -void torch_mlir::initImportOptionsBindings(py::module &m) { +void torch_mlir::initImportOptionsBindings(py::module& m) { py::class_(m, "ImportOptions") .def(py::init<>()) - .def_readwrite("assumeTensorsHaveValueSemantics", - &ImportOptions::assumeTensorsHaveValueSemantics) - .def_readwrite("ignoreExistingTensorShapesAndDtypes", - &ImportOptions::ignoreExistingTensorShapesAndDtypes); + .def_readwrite( + "assumeTensorsHaveValueSemantics", + &ImportOptions::assumeTensorsHaveValueSemantics) + .def_readwrite( + "ignoreExistingTensorShapesAndDtypes", + &ImportOptions::ignoreExistingTensorShapesAndDtypes); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h similarity index 92% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h index 6e8e1389c..4ca27a218 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/import_options_pybind.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h @@ -13,7 +13,7 @@ #include namespace torch_mlir { -void initImportOptionsBindings(pybind11::module &m); +void initImportOptionsBindings(pybind11::module& m); } // namespace torch_mlir #endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/init_python_bindings.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/init_python_bindings.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/init_python_bindings.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/init_python_bindings.cpp diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp similarity index 85% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp index 75013d5ee..73321817e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp @@ -49,10 +49,10 @@ using namespace torch_mlir; // throw an error on). namespace { struct IValueHasher { - size_t operator()(const c10::IValue &ivalue) const { + size_t operator()(const c10::IValue& ivalue) const { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { - return std::hash()( - static_cast(ivalue.internalToPointer())); + return std::hash()( + static_cast(ivalue.internalToPointer())); } return c10::IValue::hash(ivalue); @@ -65,7 +65,7 @@ struct IValueHasher { // such as when tracing). Can we do better? namespace { struct IValueEq { - bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const { + bool operator()(const c10::IValue& lhs, const c10::IValue& rhs) const { return lhs.isSameIdentity(rhs); } }; @@ -99,8 +99,9 @@ namespace { /// (PyTorch allows this!). class IValueImporter { public: - IValueImporter(MlirBlock importBlock, MlirContext context, - ClassAnnotator &annotator, const ImportOptions &importOptions) + IValueImporter( + MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator, + const ImportOptions& importOptions) : importBlock(importBlock), context(context), annotator(annotator), importOptions(importOptions) {} @@ -110,15 +111,16 @@ private: MlirValue rawImportIValue(c10::IValue ivalue); MlirValue importTensor(c10::IValue ivalue); MlirValue importModule(torch::jit::Module jitModule); - void importMethod(torch::jit::Function *function, MlirBlock classTypeBody, - const MethodAnnotation &methodAnnotation); - void importClassType(c10::ClassType *classType); - void importCompilationUnit(torch::jit::CompilationUnit *cu); + void importMethod( + torch::jit::Function* function, MlirBlock classTypeBody, + const MethodAnnotation& methodAnnotation); + void importClassType(c10::ClassType* classType); + void importCompilationUnit(torch::jit::CompilationUnit* cu); MlirBlock importBlock; MlirContext context; - ClassAnnotator &annotator; - const ImportOptions &importOptions; + ClassAnnotator& annotator; + const ImportOptions& importOptions; // Map tracking already-imported values. std::unordered_map valueMap; @@ -129,16 +131,16 @@ private: // e.g. methods (the function names are meaningful and match with Python's // module hierarchy, with the exception of `__main__` being replaced with // `__torch__`). - torch::jit::CompilationUnit *compilationUnit = nullptr; + torch::jit::CompilationUnit* compilationUnit = nullptr; // Used to detect potentially aliasing tensors. - std::unordered_set seenStorageImpls; + std::unordered_set seenStorageImpls; // The set of ClassType's that have already been imported. // // ClassType's are referenced via their `classType->name()->qualifiedName()` // string (as an MLIR symbol name) so we don't need to keep a map associating // them with the MlirOperation that they import into. - std::unordered_set classTypes; + std::unordered_set classTypes; // The stack of attribute names we have traversed to reach the current IValue. // Used for diagnostics. std::vector attributeNameStack; @@ -190,7 +192,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), mlirRegionCreate()); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); - mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); + mlirRegionAppendOwnedBlock( + nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); InserterGuard inserterGuard(importBlock, nnModule); @@ -198,13 +201,14 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { rootModuleName = moduleTypeName; } - const std::vector &slots = currentModule._ivalue()->slots(); - const std::vector &classAttributes = + const std::vector& slots = currentModule._ivalue()->slots(); + const std::vector& classAttributes = currentModule.type()->getAttributes(); - assert(slots.size() == classAttributes.size() && - "mismatch between object and type!"); + assert( + slots.size() == classAttributes.size() && + "mismatch between object and type!"); for (int i = 0, e = slots.size(); i < e; i++) { - const c10::ClassAttribute &classAttribute = classAttributes[i]; + const c10::ClassAttribute& classAttribute = classAttributes[i]; attributeNameStack.push_back(classAttribute.getName()); MlirValue slotValue = importIValue(slots[i]); // TODO: Is it necessary to track whether an attribute is a "parameter"? @@ -231,7 +235,7 @@ MlirValue IValueImporter::importIValue(c10::IValue ivalue) { } // Reject potentially aliased tensors. if (ivalue.isTensor()) { - c10::StorageImpl *storageImpl = + c10::StorageImpl* storageImpl = ivalue.toTensor().storage().unsafeGetStorageImpl(); if (!seenStorageImpls.insert(storageImpl).second) { std::stringstream msg; @@ -257,8 +261,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirType type = torchMlirTorchBoolTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.bool", loc, type, - toMlirNamedAttribute("value", - mlirBoolAttrGet(context, ivalue.toBool()))); + toMlirNamedAttribute( + "value", mlirBoolAttrGet(context, ivalue.toBool()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isDouble()) { @@ -266,23 +270,23 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.float", loc, type, toMlirNamedAttribute( - "value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), - ivalue.toDouble()))); + "value", mlirFloatAttrDoubleGet( + context, mlirF64TypeGet(context), ivalue.toDouble()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isInt()) { MlirType type = torchMlirTorchIntTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.int", loc, type, - toMlirNamedAttribute("value", - mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), - ivalue.toInt()))); + toMlirNamedAttribute( + "value", mlirIntegerAttrGet( + mlirIntegerTypeGet(context, 64), ivalue.toInt()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isList()) { c10::List list = ivalue.toList(); std::vector elems; - for (const c10::IValue &elem : list) { + for (const c10::IValue& elem : list) { elems.push_back(importIValue(elem)); } MlirOperation operation = createMlirOperationAtEnd( @@ -312,7 +316,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { auto list = ivalue.toTuple()->elements(); std::vector operands; std::vector types; - for (const c10::IValue &elem : list) { + for (const c10::IValue& elem : list) { MlirValue operand = importIValue(elem); operands.push_back(operand); types.push_back(mlirValueGetType(operand)); @@ -335,14 +339,14 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { torchMlirTorchStringTypeGet(context), toMlirNamedAttribute( "value", - mlirStringAttrGet(context, - toMlirStringRef(ivalue.toString()->string())))); + mlirStringAttrGet( + context, toMlirStringRef(ivalue.toString()->string())))); return mlirOperationGetResult(operation, 0); } if (ivalue.isNone()) { - MlirOperation operation = - createMlirOperationAtEnd(importBlock, "torch.constant.none", loc, - torchMlirTorchNoneTypeGet(context)); + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "torch.constant.none", loc, + torchMlirTorchNoneTypeGet(context)); return mlirOperationGetResult(operation, 0); } if (ivalue.isCustomClass()) { @@ -436,12 +440,12 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { return tensorValue; } -void IValueImporter::importMethod(torch::jit::Function *function, - MlirBlock classTypeBody, - const MethodAnnotation &methodAnnotation) { +void IValueImporter::importMethod( + torch::jit::Function* function, MlirBlock classTypeBody, + const MethodAnnotation& methodAnnotation) { // The function's name becomes the MLIR symbol table name of the imported func // when we import the compilation unit. - const std::string &symName = function->qualname().qualifiedName(); + const std::string& symName = function->qualname().qualifiedName(); MlirAttribute functionSymbolRef = mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)); @@ -457,7 +461,7 @@ void IValueImporter::importMethod(torch::jit::Function *function, toMlirNamedAttribute("function", functionSymbolRef), isPrivate); } -void IValueImporter::importClassType(c10::ClassType *classType) { +void IValueImporter::importClassType(c10::ClassType* classType) { if (!classTypes.insert(classType).second) { return; } @@ -475,13 +479,13 @@ void IValueImporter::importClassType(c10::ClassType *classType) { mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr)); MlirBlock classTypeBody = mlirRegionGetFirstBlock(region); - ClassAnnotation &classAnnotation = + ClassAnnotation& classAnnotation = annotator.getOrCreateClassAnnotation(classType); - const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations(); - const auto &classAttributes = classType->getAttributes(); + const auto& attributeAnnotations = classAnnotation.getAttributeAnnotations(); + const auto& classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { - const c10::ClassAttribute &classAttribute = classAttributes[i]; + const c10::ClassAttribute& classAttribute = classAttributes[i]; c10::optional isPrivate; if (!attributeAnnotations[i].isExported) { isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context)); @@ -491,13 +495,14 @@ void IValueImporter::importClassType(c10::ClassType *classType) { toMlirNamedAttribute( "name", mlirStringAttrGet( context, toMlirStringRef(classAttribute.getName()))), - toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType( - loc, classAttribute.getType(), importOptions))), + toMlirNamedAttribute( + "type", mlirTypeAttrGet(getMlirTypeFromTorchType( + loc, classAttribute.getType(), importOptions))), isPrivate); } - const auto &methodAnnotations = classAnnotation.getMethodAnnotations(); - const auto &methods = classType->methods(); + const auto& methodAnnotations = classAnnotation.getMethodAnnotations(); + const auto& methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { importMethod(methods[i], classTypeBody, methodAnnotations[i]); } @@ -505,7 +510,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) { createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc); } -void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { +void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { if (compilationUnit == nullptr) { compilationUnit = cu; } else { @@ -524,14 +529,14 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { return; } - for (torch::jit::Function *function : cu->get_functions()) { + for (torch::jit::Function* function : cu->get_functions()) { // Useful for debugging errors in free functions that end up being // unused. These can be missing when round-tripping through the on-disk // format, even though they still cause import issues when importing // through the larger Python session where they originate. // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n"; // std::cerr << *torch::jit::toGraphFunction(function).graph(); - MethodAnnotation *annotation = + MethodAnnotation* annotation = annotator.getMethodAnnotationForFunction(function); MlirOperation func = importJitFunctionAsFuncOp( context, function, @@ -539,9 +544,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { if (!annotation || !annotation->argAnnotations.has_value()) { return {nullptr}; } - c10::optional> &maybeShape = + c10::optional>& maybeShape = annotation->argAnnotations.value()[argIndex].shape; - c10::optional &maybeDtype = + c10::optional& maybeDtype = annotation->argAnnotations.value()[argIndex].dtype; bool hasValueSemantics = annotation->argAnnotations.value()[argIndex].hasValueSemantics; @@ -561,10 +566,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { // the C API constructor, when we want the "we know we have 0 sizes" // case. So use a dummy data pointer. int64_t dummy; - int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data(); + int64_t* shapeData = shape.size() == 0 ? &dummy : shape.data(); if (hasValueSemantics) { - typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(), - shapeData, dtype); + typeBound = torchMlirTorchValueTensorTypeGet( + context, shape.size(), shapeData, dtype); } else { typeBound = torchMlirTorchNonValueTensorTypeGet( context, shape.size(), shapeData, dtype); @@ -592,10 +597,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { } } -MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block, - MlirContext context, - ClassAnnotator &annotator, - const ImportOptions &importOptions) { +MlirValue torch_mlir::importIValue( + c10::IValue ivalue, MlirBlock block, MlirContext context, + ClassAnnotator& annotator, const ImportOptions& importOptions) { // When debugging module importing, it can be useful to dump as so: // if (ivalue.isModule()) // ivalue.toModule().dump(true, false, false); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h similarity index 83% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h index 7cbc7ece8..ae3deb945 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h @@ -25,9 +25,9 @@ namespace torch_mlir { /// Main entry-point for importing torch IValue's . /// Recursively imports `ivalue`, inserting operations at the end of `block`. -MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context, - ClassAnnotator &annotator, - const ImportOptions &importOptions); +MlirValue importIValue( + c10::IValue ivalue, MlirBlock block, MlirContext context, + ClassAnnotator& annotator, const ImportOptions& importOptions); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/mlir_utils.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h similarity index 54% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/mlir_utils.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h index 97ce5fa10..1e033f0d8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/mlir_utils.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h @@ -22,92 +22,92 @@ namespace torch_mlir { -inline MlirStringRef toMlirStringRef(const std::string &s) { +inline MlirStringRef toMlirStringRef(const std::string& s) { return mlirStringRefCreate(s.data(), s.size()); } -inline MlirStringRef toMlirStringRef(const char *s) { +inline MlirStringRef toMlirStringRef(const char* s) { return mlirStringRefCreate(s, std::strlen(s)); } -inline MlirNamedAttribute toMlirNamedAttribute(const char *s, - MlirAttribute attr) { +inline MlirNamedAttribute +toMlirNamedAttribute(const char* s, MlirAttribute attr) { MlirContext context = mlirAttributeGetContext(attr); MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); return mlirNamedAttributeGet(ident, attr); } -inline void addToMlirOperationState(MlirOperationState &state, - MlirNamedAttribute namedAttr) { +inline void addToMlirOperationState( + MlirOperationState& state, MlirNamedAttribute namedAttr) { mlirOperationStateAddAttributes(&state, 1, &namedAttr); } -inline void addToMlirOperationState(MlirOperationState &state, - MlirRegion region) { +inline void +addToMlirOperationState(MlirOperationState& state, MlirRegion region) { mlirOperationStateAddOwnedRegions(&state, 1, ®ion); } -inline void addToMlirOperationState(MlirOperationState &state, - MlirValue value) { +inline void +addToMlirOperationState(MlirOperationState& state, MlirValue value) { mlirOperationStateAddOperands(&state, 1, &value); } -inline void addToMlirOperationState(MlirOperationState &state, - const std::vector &values) { +inline void addToMlirOperationState( + MlirOperationState& state, const std::vector& values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } -inline void addToMlirOperationState(MlirOperationState &state, - c10::ArrayRef values) { +inline void addToMlirOperationState( + MlirOperationState& state, c10::ArrayRef values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } -inline void addToMlirOperationState(MlirOperationState &state, - MlirType resultType) { +inline void +addToMlirOperationState(MlirOperationState& state, MlirType resultType) { mlirOperationStateAddResults(&state, 1, &resultType); } -inline void addToMlirOperationState(MlirOperationState &state, - const std::vector &resultTypes) { +inline void addToMlirOperationState( + MlirOperationState& state, const std::vector& resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } -inline void addToMlirOperationState(MlirOperationState &state, - c10::ArrayRef resultTypes) { +inline void addToMlirOperationState( + MlirOperationState& state, c10::ArrayRef resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } template -void addToMlirOperationState(MlirOperationState &state, c10::optional o) { +void addToMlirOperationState(MlirOperationState& state, c10::optional o) { if (o.has_value()) { addToMlirOperationState(state, o.value()); } } -inline void addToMlirOperationState(MlirOperationState &state) {} +inline void addToMlirOperationState(MlirOperationState& state) {} template -void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, - Ts &&...ts) { +void addToMlirOperationState( + MlirOperationState& state, T&& t, U&& u, Ts&&... ts) { addToMlirOperationState(state, std::forward(t)); addToMlirOperationState(state, std::forward(u), std::forward(ts)...); } template -MlirOperation createMlirOperation(std::string name, MlirLocation loc, - Ts &&...ts) { +MlirOperation +createMlirOperation(std::string name, MlirLocation loc, Ts&&... ts) { MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); addToMlirOperationState(state, std::forward(ts)...); return mlirOperationCreate(&state); } template -MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name, - MlirLocation loc, Ts &&...ts) { +MlirOperation createMlirOperationAtEnd( + MlirBlock block, std::string name, MlirLocation loc, Ts&&... ts) { MlirOperation operation = createMlirOperation(name, loc, std::forward(ts)...); - mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block), - operation); + mlirBlockInsertOwnedOperationBefore( + block, mlirBlockGetTerminator(block), operation); return operation; } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp similarity index 76% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp index ca4bd600f..c1922f8f0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp @@ -22,7 +22,7 @@ namespace py = pybind11; using namespace torch_mlir; -static py::object getMlirIrClass(const char *className) { +static py::object getMlirIrClass(const char* className) { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className); } @@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) { return contextObj; } -static MlirContext castPythonObjectToMlirContext(py::object &contextObj) { +static MlirContext castPythonObjectToMlirContext(py::object& contextObj) { assert(!contextObj.is_none() && "context cannot be None"); auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR); MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr()); @@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) { std::stringstream ss; ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic)) << ": "; - auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) { - auto *ssp = static_cast(stringCallbackUserData); + auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) { + auto* ssp = static_cast(stringCallbackUserData); ssp->write(s.data, s.length); }; - mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); + mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); // Use pybind11's print: // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html - py::print(ss.str(), - py::arg("file") = py::module_::import("sys").attr("stderr")); + py::print( + ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr")); } // Register a diagnostic handler that will redirect output to `sys.stderr` @@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) { // that mlir diagnostics emitted are correctly routed in Jupyter notebooks. static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { auto diagnosticHandler = [](MlirDiagnostic diagnostic, - void *) -> MlirLogicalResult { + void*) -> MlirLogicalResult { printDiagnostic(diagnostic); for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) { printDiagnostic(mlirDiagnosticGetNote(diagnostic, i)); @@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { return mlirLogicalResultSuccess(); }; MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( - context, diagnosticHandler, nullptr, [](void *) { return; }); + context, diagnosticHandler, nullptr, [](void*) { return; }); // Ignore the ID. We intend to keep this handler for the entire lifetime // of this context. (void)id; @@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj) terminator = mlirBlockGetFirstOperation(getBodyBlock()); } -torch::jit::StrongFunctionPtr -ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function, - py::object maybeImportOptions) { +torch::jit::StrongFunctionPtr ModuleBuilder::importFunction( + torch::jit::StrongFunctionPtr function, py::object maybeImportOptions) { ImportOptions importOptions; if (!maybeImportOptions.is_none()) { importOptions = py::cast(maybeImportOptions); } MlirBlock block = getBodyBlock(); MlirOperation terminator = this->terminator; - MlirOperation func = importJitFunctionAsFuncOp(context, function.function_, - [](int) -> MlirAttribute { return {nullptr}; }, importOptions); + MlirOperation func = importJitFunctionAsFuncOp( + context, function.function_, + [](int) -> MlirAttribute { return {nullptr}; }, importOptions); mlirBlockInsertOwnedOperationBefore(block, terminator, func); return function; } -void ModuleBuilder::importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator, - py::object maybeImportOptions) { +void ModuleBuilder::importModule( + torch::jit::Module jitModule, py::object maybeClassAnnotator, + py::object maybeImportOptions) { ClassAnnotator dummyAnnotator; - ClassAnnotator *classAnnotator = &dummyAnnotator; + ClassAnnotator* classAnnotator = &dummyAnnotator; if (!maybeClassAnnotator.is_none()) { - classAnnotator = py::cast(maybeClassAnnotator); + classAnnotator = py::cast(maybeClassAnnotator); } ImportOptions importOptions; if (!maybeImportOptions.is_none()) { @@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule, // precise `torch.class_type` names. // // This name is not semantically load-bearing!!! - auto &name = *jitModule.type()->name(); + auto& name = *jitModule.type()->name(); auto debugModuleNameAttr = mlirStringAttrGet( context, toMlirStringRef(name.atoms()[name.atoms().size() - 1])); - mlirOperationSetAttributeByName(mlirModuleGetOperation(module), - toMlirStringRef("torch.debug_module_name"), - debugModuleNameAttr); - importIValue(jitModule._ivalue(), mlirModuleGetBody(module), - mlirModuleGetContext(module), *classAnnotator, importOptions); + mlirOperationSetAttributeByName( + mlirModuleGetOperation(module), + toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr); + importIValue( + jitModule._ivalue(), mlirModuleGetBody(module), + mlirModuleGetContext(module), *classAnnotator, importOptions); } MlirBlock ModuleBuilder::getBodyBlock() { @@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() { return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0)); } -void ModuleBuilder::bind(py::module &m) { +void ModuleBuilder::bind(py::module& m) { py::class_(m, "ModuleBuilder") .def(py::init(), py::arg("context") = py::none()) .def_property_readonly("context", &ModuleBuilder::getContextObj) .def_property_readonly("module", &ModuleBuilder::getModuleObj) - .def("import_function", &ModuleBuilder::importFunction, py::arg("function"), - py::arg("importOptions") = py::none()) - .def("import_module", &ModuleBuilder::importModule, py::arg("module"), - py::arg("classAnnotator") = py::none(), - py::arg("importOptions") = py::none()); + .def( + "import_function", &ModuleBuilder::importFunction, + py::arg("function"), py::arg("importOptions") = py::none()) + .def( + "import_module", &ModuleBuilder::importModule, py::arg("module"), + py::arg("classAnnotator") = py::none(), + py::arg("importOptions") = py::none()); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h similarity index 85% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h index 08695e15f..273778c41 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/module_builder.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h @@ -29,7 +29,7 @@ public: ModuleBuilder(pybind11::object contextObj); /// Creates Python bindings for the class. - static void bind(pybind11::module &m); + static void bind(pybind11::module& m); pybind11::object getContextObj() { return contextObj; } pybind11::object getModuleObj() { return moduleObj; } @@ -38,16 +38,15 @@ public: // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr. // Just a bit of naming cruft. // Returns the same function, making it suitable as a nested decorator. - torch::jit::StrongFunctionPtr - importFunction(torch::jit::StrongFunctionPtr function, - py::object maybeImportOptions); + torch::jit::StrongFunctionPtr importFunction( + torch::jit::StrongFunctionPtr function, py::object maybeImportOptions); // Imports a torch::jit::Module into the current module, using the // annotations, if not none, provided in `maybeClassAnnotator` which should be // a ClassAnnotator. - void importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator, - py::object maybeImportOptions); + void importModule( + torch::jit::Module jitModule, py::object maybeClassAnnotator, + py::object maybeImportOptions); private: MlirBlock getBodyBlock(); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp similarity index 67% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp index 15cffedbe..e9be84acc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp @@ -33,41 +33,42 @@ class NodeImporter { public: NodeImporter(MlirContext context) : context(context) {} - void importNode(Node *node, MlirBlock appendToBlock, - const ImportOptions &importOptions = {}); + void importNode( + Node* node, MlirBlock appendToBlock, + const ImportOptions& importOptions = {}); MlirBlock importBlock( - Block *jitBlock, CreateTerminatorFn createTerminator, + Block* jitBlock, CreateTerminatorFn createTerminator, c10::optional> blockArgTypes = c10::nullopt, - const ImportOptions &importOptions = {}); + const ImportOptions& importOptions = {}); private: - MlirBlock - createBlockFor(Block *jitBlock, - c10::optional> blockArgTypes, - const ImportOptions &importOptions = {}); - void mapValue(Value *jitValue, MlirValue value); - void mapResults(Node *node, MlirOperation operation); - MlirValue lookupMappedValue(Value *jitValue); - std::vector lookupMappedValues(c10::ArrayRef values); + MlirBlock createBlockFor( + Block* jitBlock, c10::optional> blockArgTypes, + const ImportOptions& importOptions = {}); + void mapValue(Value* jitValue, MlirValue value); + void mapResults(Node* node, MlirOperation operation); + MlirValue lookupMappedValue(Value* jitValue); + std::vector lookupMappedValues(c10::ArrayRef values); MlirContext context; - std::unordered_map valueMap; + std::unordered_map valueMap; }; } // namespace using InputsTransformFn = - std::function(std::vector &)>; + std::function(std::vector&)>; // The inputs of `DictConstruct` in TorchScript IR are in the order // like k0, v0, k1, v1. Rearrange them to put the key operands together and // then the value operands like k0, k1,v0, v1. This is the expected format by // the corresponding MLIR op. static std::vector -rearrangeDictConstructInputs(std::vector &inputs) { +rearrangeDictConstructInputs(std::vector& inputs) { if (inputs.empty()) return inputs; - assert(inputs.size() % 2 == 0 && - "DictConstruct must have even number of operands"); + assert( + inputs.size() % 2 == 0 && + "DictConstruct must have even number of operands"); std::vector rearranged; std::vector values; @@ -79,12 +80,12 @@ rearrangeDictConstructInputs(std::vector &inputs) { return rearranged; } -void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, - const ImportOptions &importOptions) { +void NodeImporter::importNode( + Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) { MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); - auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, + auto createAndMapTrivialNode = [&](Node* node, const std::string& opName, InputsTransformFn t) { std::vector mappedInputs = lookupMappedValues(node->inputs()); MlirOperation operation = createMlirOperationAtEnd( @@ -95,7 +96,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, }; auto createAndMapNodeWithAttribute = - [&](Node *node, const std::string &opName, const std::string &attrName, + [&](Node* node, const std::string& opName, const std::string& attrName, MlirAttribute attr) { MlirOperation operation = createMlirOperationAtEnd( appendToBlock, opName, loc, @@ -132,27 +133,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, // ListConstruct and DictConstruct too. auto containedTypes = c10::fmap( node->output()->type()->cast()->containedTypes(), - [&](const c10::TypePtr &t) { + [&](const c10::TypePtr& t) { MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); if (mlirTypeIsNull(type)) { throw mlir_diagnostic_emitted(); } return type; }); - createAndMapTrivialNode(node, - "torch.prim." + std::string(kind.toUnqualString()), - [&](std::vector &inputs) { - assert(containedTypes.size() == inputs.size()); - return adjustStaticInformationForValues( - appendToBlock, loc, inputs, containedTypes, - /*userAllowsRefinement=*/true); - }); + createAndMapTrivialNode( + node, "torch.prim." + std::string(kind.toUnqualString()), + [&](std::vector& inputs) { + assert(containedTypes.size() == inputs.size()); + return adjustStaticInformationForValues( + appendToBlock, loc, inputs, containedTypes, + /*userAllowsRefinement=*/true); + }); return; } case c10::prim::DictConstruct: { - createAndMapTrivialNode(node, - "torch.prim." + std::string(kind.toUnqualString()), - rearrangeDictConstructInputs); + createAndMapTrivialNode( + node, "torch.prim." + std::string(kind.toUnqualString()), + rearrangeDictConstructInputs); return; } case c10::prim::Load: @@ -170,32 +171,34 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, auto output = node->output(); MlirOperation op; if (output->type()->cast()) { - op = createMlirOperation("torch.constant.none", loc, - torchMlirTorchNoneTypeGet(context)); + op = createMlirOperation( + "torch.constant.none", loc, torchMlirTorchNoneTypeGet(context)); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context), toMlirNamedAttribute( - "value", mlirBoolAttrGet(context, static_cast(node->i( - c10::attr::value))))); + "value", + mlirBoolAttrGet( + context, static_cast(node->i(c10::attr::value))))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.int", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), - toMlirNamedAttribute("value", - importAttribute(loc, node, c10::attr::value))); + toMlirNamedAttribute( + "value", importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.float", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), - toMlirNamedAttribute("value", - importAttribute(loc, node, c10::attr::value))); + toMlirNamedAttribute( + "value", importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.str", loc, torchMlirTorchStringTypeGet(context), toMlirNamedAttribute( - "value", mlirStringAttrGet(context, toMlirStringRef(node->s( - c10::attr::value))))); + "value", + mlirStringAttrGet( + context, toMlirStringRef(node->s(c10::attr::value))))); } else if (output->type()->cast()) { MlirAttribute attr = importAttribute(loc, node, c10::attr::value); if (importOptions.assumeTensorsHaveValueSemantics) { @@ -214,24 +217,26 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, "torch.constant.device", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), toMlirNamedAttribute( - "value", mlirStringAttrGet(context, toMlirStringRef(node->s( - c10::attr::value))))); + "value", + mlirStringAttrGet( + context, toMlirStringRef(node->s(c10::attr::value))))); } else if (auto functionType = output->type()->cast()) { - torch::jit::Function *function = functionType->function(); - const std::string &symName = function->qualname().qualifiedName(); + torch::jit::Function* function = functionType->function(); + const std::string& symName = function->qualname().qualifiedName(); op = createMlirOperation( "func.constant", loc, - getFunctionTypeFromSchema(context, function->getSchema(), - importOptions), + getFunctionTypeFromSchema( + context, function->getSchema(), importOptions), toMlirNamedAttribute( "value", mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); - } else if (output->type()->cast() || - output->type()->cast()) { + } else if ( + output->type()->cast() || + output->type()->cast()) { ClassAnnotator dummyAnnotator; - MlirValue listOrTupleValue = - importIValue(node->ival(c10::attr::value), appendToBlock, context, - dummyAnnotator, importOptions); + MlirValue listOrTupleValue = importIValue( + node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator, + importOptions); mapResults(node, mlirOpResultGetOwner(listOrTupleValue)); return; // Early return, since `importIValue` already added op to block. } else { @@ -259,19 +264,20 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, mapResults(node, operation); std::vector terminatorOperandTypes = { torchMlirTorchBoolTypeGet(context)}; - terminatorOperandTypes.insert(terminatorOperandTypes.end(), - resultTypes.begin(), resultTypes.end()); + terminatorOperandTypes.insert( + terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end()); auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { createMlirOperationAtEnd( appendToBlock, "torch.prim.Loop.condition", loc, - adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, - terminatorOperandTypes, - /*userAllowsRefinement=*/false)); + adjustStaticInformationForValues( + appendToBlock, loc, yieldedValues, terminatorOperandTypes, + /*userAllowsRefinement=*/false)); }; mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); + importBlock( + node->blocks()[0], createTerminator, c10::nullopt, importOptions)); return; } @@ -286,25 +292,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, MlirBlock appendToBlock) { createMlirOperationAtEnd( appendToBlock, "torch.prim.If.yield", loc, - adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, - resultTypes, - /*userAllowsRefinement=*/false)); + adjustStaticInformationForValues( + appendToBlock, loc, yieldedValues, resultTypes, + /*userAllowsRefinement=*/false)); }; mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); + importBlock( + node->blocks()[0], createTerminator, c10::nullopt, importOptions)); mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 1), - importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions)); + importBlock( + node->blocks()[1], createTerminator, c10::nullopt, importOptions)); return; } if (kind == c10::prim::CallMethod) { auto classType = node->input(0)->type()->cast(); auto methodName = node->s(c10::attr::name); - torch::jit::Function *function = classType->findMethod(methodName); - MlirType calleeType = - getFunctionTypeFromSchema(context, function->getSchema(), importOptions); + torch::jit::Function* function = classType->findMethod(methodName); + MlirType calleeType = getFunctionTypeFromSchema( + context, function->getSchema(), importOptions); std::vector expectedTypes; for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) { expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i)); @@ -315,17 +323,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, adjustStaticInformationForValues( appendToBlock, loc, lookupMappedValues(node->inputs()), expectedTypes, /*userAllowsRefinement=*/false), - toMlirNamedAttribute("name", - importAttribute(loc, node, c10::attr::name))); + toMlirNamedAttribute( + "name", importAttribute(loc, node, c10::attr::name))); mapResults(node, operation); return; } if (kind == c10::prim::CallFunction) { auto functionType = node->input(0)->type()->cast(); - torch::jit::Block *calleeEntryBlock = + torch::jit::Block* calleeEntryBlock = torch::jit::toGraphFunction(*functionType->function()).graph()->block(); - auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { + auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value* v) { return getMlirTypeFromTorchType(loc, v->type(), importOptions); }); std::string functionName = node->input(0)->node()->s(c10::attr::name); @@ -340,9 +348,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, // promoted result dtype for a PyTorch computation. Here we turn the call to // this function to the torch dialect equivalent op `torch.promote_dtypes`. if (functionName == "__torch_mlir_internal_promote_dtypes") { - operation = - createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc, - resultTypes, adjustedFuncArgs); + operation = createMlirOperationAtEnd( + appendToBlock, "torch.promote_dtypes", loc, resultTypes, + adjustedFuncArgs); } else { operation = createMlirOperationAtEnd( appendToBlock, "func.call_indirect", loc, resultTypes, @@ -362,22 +370,22 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, } MlirBlock NodeImporter::importBlock( - Block *jitBlock, CreateTerminatorFn createTerminator, + Block* jitBlock, CreateTerminatorFn createTerminator, c10::optional> blockArgTypes, - const ImportOptions &importOptions) { + const ImportOptions& importOptions) { MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions); - for (Node *node : jitBlock->nodes()) { + for (Node* node : jitBlock->nodes()) { importNode(node, block, importOptions); } - Node *returnNode = jitBlock->return_node(); + Node* returnNode = jitBlock->return_node(); createTerminator(lookupMappedValues(returnNode->inputs()), block); return block; } MlirBlock NodeImporter::createBlockFor( - Block *jitBlock, c10::optional> blockArgTypes, - const ImportOptions &importOptions) { - Node *paramNode = jitBlock->param_node(); + Block* jitBlock, c10::optional> blockArgTypes, + const ImportOptions& importOptions) { + Node* paramNode = jitBlock->param_node(); MlirLocation loc = getMlirLocationFromNode(context, paramNode); std::vector paramNodeTypes = getMlirTypesFromValues(loc, paramNode->outputs(), importOptions); @@ -386,11 +394,11 @@ MlirBlock NodeImporter::createBlockFor( else assert(blockArgTypes->size() == paramNodeTypes.size()); std::vector blockArgLocs(paramNodeTypes.size(), loc); - MlirBlock block = - mlirBlockCreate(blockArgTypes.value().size(), - blockArgTypes.value().data(), blockArgLocs.data()); + MlirBlock block = mlirBlockCreate( + blockArgTypes.value().size(), blockArgTypes.value().data(), + blockArgLocs.data()); for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) { - Value *jitValue = paramNode->outputs()[i]; + Value* jitValue = paramNode->outputs()[i]; MlirValue value = mlirBlockGetArgument(block, i); MlirValue adjusted = adjustStaticInformationForValues( block, loc, {value}, {paramNodeTypes[i]}, @@ -400,39 +408,40 @@ MlirBlock NodeImporter::createBlockFor( return block; } -void NodeImporter::mapValue(Value *jitValue, MlirValue value) { +void NodeImporter::mapValue(Value* jitValue, MlirValue value) { auto it = valueMap.find(jitValue); (void)it; assert(it == valueMap.end() && "jitValue has already been mapped"); valueMap[jitValue] = value; } -void NodeImporter::mapResults(Node *node, MlirOperation operation) { - assert(node->outputs().size() == - (size_t)mlirOperationGetNumResults(operation)); +void NodeImporter::mapResults(Node* node, MlirOperation operation) { + assert( + node->outputs().size() == (size_t)mlirOperationGetNumResults(operation)); for (int i = 0, e = node->outputs().size(); i < e; i++) { mapValue(node->outputs()[i], mlirOperationGetResult(operation, i)); } } -MlirValue NodeImporter::lookupMappedValue(Value *jitValue) { +MlirValue NodeImporter::lookupMappedValue(Value* jitValue) { auto it = valueMap.find(jitValue); - assert(it != valueMap.end() && - "trying to get mapping for jitValue that is not mapped yet!"); + assert( + it != valueMap.end() && + "trying to get mapping for jitValue that is not mapped yet!"); return it->second; } std::vector -NodeImporter::lookupMappedValues(c10::ArrayRef values) { +NodeImporter::lookupMappedValues(c10::ArrayRef values) { std::vector ret; - for (Value *value : values) { + for (Value* value : values) { ret.push_back(lookupMappedValue(value)); } return ret; } -MlirBlock -torch_mlir::importBlock(MlirContext context, Block *jitBlock, - CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes, - const ImportOptions &importOptions) { +MlirBlock torch_mlir::importBlock( + MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes, + const ImportOptions& importOptions) { NodeImporter importer(context); - return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions); + return importer.importBlock( + jitBlock, createTerminator, blockArgTypes, importOptions); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h similarity index 94% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h index dd01444f4..f36352058 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/node_importer.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h @@ -37,10 +37,10 @@ using CreateTerminatorFn = /// adjust the types to the block argument types. /// TODO: Formalize what type conversions are allowed here. MlirBlock importBlock( - MlirContext context, torch::jit::Block *jitBlock, + MlirContext context, torch::jit::Block* jitBlock, CreateTerminatorFn createTerminator, c10::optional> blockArgTypes = c10::nullopt, - const ImportOptions &importOptions = {}); + const ImportOptions& importOptions = {}); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp similarity index 74% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp index afac7b164..fc8858734 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp @@ -26,8 +26,8 @@ using namespace torch_mlir; -static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, - c10::ScalarType scalarType) { +static MlirType getMlirTypeForTorchScalarTypeRaw( + MlirContext context, c10::ScalarType scalarType) { using c10::ScalarType; switch (scalarType) { case ScalarType::Byte: @@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, } } -MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc, - c10::ScalarType scalarType) { +MlirType torch_mlir::getMlirTypeForTorchScalarType( + MlirLocation loc, c10::ScalarType scalarType) { auto type = getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType); if (mlirTypeIsNull(type)) { @@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc, // There is no generic way to import custom classes (or their types), so we // have to name match them here (and the relevant code in the ivalue // importer) and create special IR constructs for them. -static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, - const c10::ClassTypePtr &classType) { +static MlirType mapCustomClassType( + MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) { // If the type is unnamed, it cannot be a custom class. if (!classType->name().has_value()) { return {nullptr}; @@ -126,10 +126,9 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, throw mlir_diagnostic_emitted(); } -MlirType -torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, - const c10::TypePtr &torchType, - const ImportOptions &importOptions) { +MlirType torch_mlir::getMlirTypeFromTorchType( + MlirLocation loc, const c10::TypePtr& torchType, + const ImportOptions& importOptions) { MlirContext context = mlirLocationGetContext(loc); using c10::TypeKind; auto kind = torchType->kind(); @@ -141,10 +140,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, : torchMlirTorchNonValueTensorTypeGet; if (importOptions.ignoreExistingTensorShapesAndDtypes) { - return getMlirTensorType(context, - /*numSizes=*/-1, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/{nullptr}); + return getMlirTensorType( + context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/{nullptr}); } // Element type. @@ -156,17 +156,18 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, return {nullptr}; } // Sizes. - auto &sizes = tensorType->symbolic_sizes(); + auto& sizes = tensorType->symbolic_sizes(); if (!sizes.rank()) { // Unranked. - return getMlirTensorType(context, - /*numSizes=*/-1, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/ - elementType); + return getMlirTensorType( + context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/ + elementType); } // Ranked with possibly dynamic dims. - auto &symbolicShape = tensorType->symbolic_sizes(); + auto& symbolicShape = tensorType->symbolic_sizes(); std::vector dims; dims.resize(*sizes.rank()); for (size_t i = 0; i < dims.size(); ++i) { @@ -179,11 +180,12 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, // the C API constructor, when we want the "we know we have 0 sizes" // case. So use a dummy data pointer. int64_t dummy; - int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data(); - return getMlirTensorType(context, dims.size(), - /*optionalSizes=*/dimsData, - /*optionalDtype=*/ - elementType); + int64_t* dimsData = dims.size() == 0 ? &dummy : dims.data(); + return getMlirTensorType( + context, dims.size(), + /*optionalSizes=*/dimsData, + /*optionalDtype=*/ + elementType); } case TypeKind::IntType: { return torchMlirTorchIntTypeGet(context); @@ -207,22 +209,22 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, } case TypeKind::TupleType: { std::vector containedTypes; - for (const c10::TypePtr &type : + for (const c10::TypePtr& type : torchType->cast()->containedTypes()) { containedTypes.push_back( getMlirTypeFromTorchType(loc, type, importOptions)); } - return torchMlirTorchTupleTypeGet(context, containedTypes.size(), - containedTypes.data()); + return torchMlirTorchTupleTypeGet( + context, containedTypes.size(), containedTypes.data()); } case TypeKind::UnionType: { std::vector containedTypes; - for (const c10::TypePtr &type : + for (const c10::TypePtr& type : torchType->cast()->containedTypes()) { containedTypes.push_back(getMlirTypeFromTorchType(loc, type)); } - return torchMlirTorchUnionTypeGet(context, containedTypes.size(), - containedTypes.data()); + return torchMlirTorchUnionTypeGet( + context, containedTypes.size(), containedTypes.data()); } case TypeKind::ListType: { return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( @@ -242,7 +244,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, return torchMlirTorchAnyTypeGet(context); } case TypeKind::ClassType: { - const c10::ClassTypePtr &classType = torchType->cast(); + const c10::ClassTypePtr& classType = torchType->cast(); MlirType customClassType = mapCustomClassType(context, loc, classType); if (!mlirTypeIsNull(customClassType)) { return customClassType; @@ -266,12 +268,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, } } -MlirType -torch_mlir::getFunctionTypeFromSchema(MlirContext context, - const c10::FunctionSchema &schema, - const ImportOptions &importOptions) { +MlirType torch_mlir::getFunctionTypeFromSchema( + MlirContext context, const c10::FunctionSchema& schema, + const ImportOptions& importOptions) { MlirLocation loc = mlirLocationUnknownGet(context); - auto mapType = [&](const c10::TypePtr &torchType) { + auto mapType = [&](const c10::TypePtr& torchType) { MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions); if (mlirTypeIsNull(type)) { std::stringstream msg; @@ -283,17 +284,20 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context, }; std::vector inputTypes = - c10::fmap(schema.arguments(), - [&](const c10::Argument &arg) { return mapType(arg.type()); }); + c10::fmap(schema.arguments(), [&](const c10::Argument& arg) { + return mapType(arg.type()); + }); std::vector outputTypes = - c10::fmap(schema.returns(), - [&](const c10::Argument &arg) { return mapType(arg.type()); }); - return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), - outputTypes.size(), outputTypes.data()); + c10::fmap(schema.returns(), [&](const c10::Argument& arg) { + return mapType(arg.type()); + }); + return mlirFunctionTypeGet( + context, inputTypes.size(), inputTypes.data(), outputTypes.size(), + outputTypes.data()); } -MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, - MlirLocation loc) { +MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( + at::Tensor tensor, MlirLocation loc) { using at::ScalarType; auto throwUnsupportedTensorError = [&]() { @@ -308,8 +312,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, // The flat number of bytes throws an exception for tensors that are not // dense and accessible as such. - at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor, - c10::Layout::Strided); + at::checkLayout( + at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided); // Construct the ShapedType. @@ -334,47 +338,47 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, switch (tensor.scalar_type()) { case ScalarType::Int: return mlirDenseElementsAttrInt32Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Long: return mlirDenseElementsAttrInt64Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Float: return mlirDenseElementsAttrFloatGet( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Double: return mlirDenseElementsAttrDoubleGet( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Bool: { // TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed // upstream to take in a `const bool *` rather than a `const int *` to avoid // the unnecessary copying into an array four times as large. - const int8_t *elements = static_cast(tensorData); + const int8_t* elements = static_cast(tensorData); std::vector tensorDataVector(elements, elements + numElements); - return mlirDenseElementsAttrBoolGet(shapedType, numElements, - tensorDataVector.data()); + return mlirDenseElementsAttrBoolGet( + shapedType, numElements, tensorDataVector.data()); } break; case ScalarType::QInt8: return mlirDenseElementsAttrInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::QUInt8: return mlirDenseElementsAttrUInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::BFloat16: return mlirDenseElementsAttrBFloat16Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Half: return mlirDenseElementsAttrFloat16Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Byte: return mlirDenseElementsAttrUInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Char: return mlirDenseElementsAttrInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); default: throwUnsupportedTensorError(); @@ -382,9 +386,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, return {nullptr}; // Unreachable. } -MlirAttribute torch_mlir::importAttribute(MlirLocation loc, - torch::jit::Node *node, - c10::Symbol symbol) { +MlirAttribute torch_mlir::importAttribute( + MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) { MlirContext context = mlirLocationGetContext(loc); auto kind = node->kindOf(symbol); switch (kind) { @@ -393,8 +396,8 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc, // do that. return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol)); case torch::jit::AttributeKind::f: - return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), - node->f(symbol)); + return mlirFloatAttrDoubleGet( + context, mlirF64TypeGet(context), node->f(symbol)); case torch::jit::AttributeKind::s: return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol))); case torch::jit::AttributeKind::t: @@ -408,23 +411,23 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc, } } -MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, - torch::jit::Node *node) { +MlirLocation torch_mlir::getMlirLocationFromNode( + MlirContext context, torch::jit::Node* node) { MlirLocation loc = mlirLocationUnknownGet(context); if (node->hasAttribute(c10::Symbol::attr("source_files"))) { - const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files")); - const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers")); - const auto &functions = node->ss(c10::Symbol::attr("functions")); + const auto& sourceFiles = node->ss(c10::Symbol::attr("source_files")); + const auto& lineNumbers = node->is(c10::Symbol::attr("line_numbers")); + const auto& functions = node->ss(c10::Symbol::attr("functions")); // Chain a sequence of calls to construct single MlirLocation. for (const auto i : c10::irange(sourceFiles.size())) { MlirLocation newLoc = mlirLocationNameGet( context, toMlirStringRef(functions[i]), - mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]), - lineNumbers[i], - 0 /* column is not available */ - )); + mlirLocationFileLineColGet( + context, toMlirStringRef(sourceFiles[i]), lineNumbers[i], + 0 /* column is not available */ + )); loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc)); } if (sourceFiles.size() == 1) { @@ -433,7 +436,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context)); } } else if (auto flc = node->sourceRange().file_line_col()) { - const std::string &file = std::get<0>(*flc); + const std::string& file = std::get<0>(*flc); int line = std::get<1>(*flc); int col = std::get<2>(*flc); loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col); @@ -445,7 +448,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, locationName = scopeName; } - if (const c10::FunctionSchema *schema = node->maybeSchema()) { + if (const c10::FunctionSchema* schema = node->maybeSchema()) { if (!locationName.empty()) { locationName += "/"; } @@ -459,10 +462,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, return loc; } -std::vector -torch_mlir::getMlirTypesFromValues(MlirLocation loc, - c10::ArrayRef values, - const ImportOptions &importOptions) { +std::vector torch_mlir::getMlirTypesFromValues( + MlirLocation loc, c10::ArrayRef values, + const ImportOptions& importOptions) { std::vector ret; for (auto value : values) { MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions); @@ -491,25 +493,24 @@ std::vector torch_mlir::adjustStaticInformationForValues( } std::stringstream msg; - MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) { - std::stringstream *stream = static_cast(userData); + MlirStringCallback printToStream = +[](MlirStringRef str, void* userData) { + std::stringstream* stream = static_cast(userData); stream->write(str.data, str.length); }; msg << "unhandled: could not adjust static info for type from "; - mlirTypePrint(type, printToStream, static_cast(&msg)); + mlirTypePrint(type, printToStream, static_cast(&msg)); msg << " to type "; - mlirTypePrint(expectedType, printToStream, static_cast(&msg)); + mlirTypePrint(expectedType, printToStream, static_cast(&msg)); mlirEmitError(loc, msg.str().c_str()); throw mlir_diagnostic_emitted(); } return ret; } -MlirOperation -torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc, - const c10::FunctionSchema &schema, - c10::ArrayRef resultTypes, - c10::ArrayRef operands) { +MlirOperation torch_mlir::createOperationFromSchema( + MlirBlock appendToBlock, MlirLocation loc, + const c10::FunctionSchema& schema, c10::ArrayRef resultTypes, + c10::ArrayRef operands) { MlirContext context = mlirLocationGetContext(loc); // Munge the name into the appropriate MLIR operation name. @@ -519,15 +520,15 @@ torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc, auto separatorPosition = opNameSuffix.find_first_of("::"); assert(separatorPosition != std::string::npos); opNameSuffix.replace(separatorPosition, 2, "."); - const std::string &overloadName = schema.overload_name(); + const std::string& overloadName = schema.overload_name(); if (!overloadName.empty()) { opNameSuffix = opNameSuffix + "." + overloadName; } std::string opName = "torch." + opNameSuffix; // If we have a registered op, use it! if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) { - return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes, - operands); + return createMlirOperationAtEnd( + appendToBlock, opName, loc, resultTypes, operands); } // Oops, no registered op -- create an opaque wrapper so that import can // still succeed. This helps a common use case of filling out registered ops diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h similarity index 61% rename from projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h index 82f394999..eea49999b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/csrc/torch_to_mlir_utils.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h @@ -25,7 +25,7 @@ namespace torch_mlir { /// Thrown on failure when details are in MLIR emitted diagnostics. class mlir_diagnostic_emitted : public std::runtime_error { public: - mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} + mlir_diagnostic_emitted(const char* what) : std::runtime_error(what) {} mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} }; @@ -38,37 +38,36 @@ public: /// for Python code). /// /// Returns a null type on failure and emits a diagnostic. -MlirType getMlirTypeForTorchScalarType(MlirLocation loc, - c10::ScalarType scalarType); +MlirType +getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType); /// Maps a torch type to a corresponding MlirType. Returns a null type /// on failure and emits a diagnostic. -MlirType getMlirTypeFromTorchType(MlirLocation loc, - const c10::TypePtr &torchType, - const ImportOptions &importOptions = {}); +MlirType getMlirTypeFromTorchType( + MlirLocation loc, const c10::TypePtr& torchType, + const ImportOptions& importOptions = {}); /// Creates a FunctionType suitable for expressing the signature of `schema`. /// /// This can differ from the type inferred from the block of a /// torch::jit::Function due to derefinement and refinement of tensor types. -MlirType getFunctionTypeFromSchema(MlirContext context, - const c10::FunctionSchema &schema, - const ImportOptions &importOptions = {}); +MlirType getFunctionTypeFromSchema( + MlirContext context, const c10::FunctionSchema& schema, + const ImportOptions& importOptions = {}); /// Creates an appropriate MlirAttribute that holds the same values as `tensor`. -MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor, - MlirLocation loc); +MlirAttribute +convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc); -MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node, - c10::Symbol symbol); +MlirAttribute +importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol); -MlirLocation getMlirLocationFromNode(MlirContext context, - torch::jit::Node *node); +MlirLocation +getMlirLocationFromNode(MlirContext context, torch::jit::Node* node); -std::vector -getMlirTypesFromValues(MlirLocation loc, - c10::ArrayRef values, - const ImportOptions &importOptions = {}); +std::vector getMlirTypesFromValues( + MlirLocation loc, c10::ArrayRef values, + const ImportOptions& importOptions = {}); std::vector adjustStaticInformationForValues( MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef values, @@ -79,11 +78,10 @@ std::vector adjustStaticInformationForValues( /// /// The primary difficulty here is doing the appropriate name munging and /// checking if the have a registered op. -MlirOperation createOperationFromSchema(MlirBlock appendToBlock, - MlirLocation loc, - const c10::FunctionSchema &schema, - c10::ArrayRef resultTypes, - c10::ArrayRef operands); +MlirOperation createOperationFromSchema( + MlirBlock appendToBlock, MlirLocation loc, + const c10::FunctionSchema& schema, c10::ArrayRef resultTypes, + c10::ArrayRef operands); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 1064a3d1e..4bcb9347b 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -14,12 +14,12 @@ #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include "backend_impl.h" diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index c575d9dd2..f4b8cd9ba 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -11,10 +11,10 @@ #include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt index 36bd8cafd..c2883b3dc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt @@ -2,8 +2,6 @@ # Subdirectories #------------------------------------------------------------------------------- -add_subdirectory(csrc) - ## Declare the sources of the Python module. declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter