diff --git a/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt b/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt index ccd114f8b..c3a18951f 100644 --- a/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt +++ b/external/torch-mlir/TorchPlugin/csrc/CMakeLists.txt @@ -13,14 +13,13 @@ include_directories(BEFORE link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(TorchMLIRTorchPlugin SHARED - builder/class_annotator.cpp - builder/function_importer.cpp - builder/module_builder.cpp - builder/node_importer.cpp - builder/ivalue_importer.cpp - builder/python_bindings.cpp - builder/torch_to_mlir_utils.cpp - init_python_bindings.cpp + class_annotator.cpp + function_importer.cpp + module_builder.cpp + node_importer.cpp + ivalue_importer.cpp + python_bindings.cpp + torch_to_mlir_utils.cpp ) target_link_libraries(TorchMLIRTorchPlugin diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.cpp b/external/torch-mlir/TorchPlugin/csrc/class_annotator.cpp similarity index 90% rename from external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.cpp rename to external/torch-mlir/TorchPlugin/csrc/class_annotator.cpp index e38f52a6e..0de9a1057 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.cpp +++ b/external/torch-mlir/TorchPlugin/csrc/class_annotator.cpp @@ -34,13 +34,12 @@ static std::string indentString(const std::string &linePrefix, //===----------------------------------------------------------------------===// ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType) -: classType(classType) { + : classType(classType) { attributeAnnotations.resize(classType->getAttributes().size()); methodAnnotations.resize(classType->methods().size()); } -std::vector & -ClassAnnotation::getAttributeAnnotations() { +std::vector &ClassAnnotation::getAttributeAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // @@ -53,8 +52,7 @@ ClassAnnotation::getAttributeAnnotations() { return attributeAnnotations; } -std::vector & -ClassAnnotation::getMethodAnnotations() { +std::vector &ClassAnnotation::getMethodAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // @@ -80,7 +78,8 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator, methodAnnotation.isExported = false; } for (auto &classAttribute : classType->getAttributes()) { - if (auto childClassType = classAttribute.getType()->cast()) { + if (auto childClassType = + classAttribute.getType()->cast()) { exportNoneRecurse(classAnnotator, childClassType.get()); } } @@ -107,7 +106,7 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType, ss << "class '" << classType->name()->qualifiedName() << "' does not have a method or attribute called '" << exportedPath.back() << "'"; - throw std::invalid_argument(ss.str()); + throw std::invalid_argument(ss.str()); } ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); std::vector &attributeAnnotations = @@ -198,10 +197,9 @@ void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, throw std::invalid_argument("Empty annotated path. Can only annotate " "shapes/dtypes of a method of a class."); } - c10::ClassType *classType = - getClassAtPath(&rootClassType, c10::ArrayRef(path) - .slice(0, path.size() - 1) - .vec()); + c10::ClassType *classType = getClassAtPath( + &rootClassType, + c10::ArrayRef(path).slice(0, path.size() - 1).vec()); // Throw error if no method on the class of the specified name. torch::jit::Function *function = &classType->getMethod(path.back()); @@ -265,26 +263,26 @@ std::string AttributeAnnotation::toString(const std::string &name) { } std::string ArgAnnotation::toString(int argIndex) { - std::stringstream ss; - ss << "ArgAnnotation(" << argIndex << ") {\n"; - ss << " dtype = " << (dtype ? c10::toString(*dtype) : "") << "\n"; - ss << " shape = "; - if (shape) { - ss << "["; - for (int i = 0, e = shape.value().size(); i != e; i++) { - if (i) { - ss << ", "; - } - ss << shape.value()[i]; + std::stringstream ss; + ss << "ArgAnnotation(" << argIndex << ") {\n"; + ss << " dtype = " << (dtype ? c10::toString(*dtype) : "") << "\n"; + ss << " shape = "; + if (shape) { + ss << "["; + for (int i = 0, e = shape.value().size(); i != e; i++) { + if (i) { + ss << ", "; } - ss << "]\n"; - } else { - ss << "\n"; + ss << shape.value()[i]; } - ss << " hasValueSemantics = " << (hasValueSemantics ? "true" : "false") - << "\n"; - ss << "}\n"; - return ss.str(); + ss << "]\n"; + } else { + ss << "\n"; + } + ss << " hasValueSemantics = " << (hasValueSemantics ? "true" : "false") + << "\n"; + ss << "}\n"; + return ss.str(); } std::string MethodAnnotation::toString(const std::string &name) { @@ -326,7 +324,7 @@ std::string ClassAnnotator::toString() { std::stringstream ss; ss << "ClassAnnotator {\n"; for (auto &p : classAnnotations) { - ss << indentString(" ", p.second->toString()); + ss << indentString(" ", p.second->toString()); } ss << "}\n"; return ss.str(); diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.h b/external/torch-mlir/TorchPlugin/csrc/class_annotator.h similarity index 99% rename from external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.h rename to external/torch-mlir/TorchPlugin/csrc/class_annotator.h index 4ee6f408c..e3dea3679 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/class_annotator.h +++ b/external/torch-mlir/TorchPlugin/csrc/class_annotator.h @@ -21,7 +21,7 @@ #ifndef TORCHMLIRPLUGIN_CSRC_CLASS_ANNOTATOR_H #define TORCHMLIRPLUGIN_CSRC_CLASS_ANNOTATOR_H -#include "../pybind.h" +#include "pybind.h" namespace torch_mlir { diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/function_importer.cpp b/external/torch-mlir/TorchPlugin/csrc/function_importer.cpp similarity index 100% rename from external/torch-mlir/TorchPlugin/csrc/builder/function_importer.cpp rename to external/torch-mlir/TorchPlugin/csrc/function_importer.cpp diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/function_importer.h b/external/torch-mlir/TorchPlugin/csrc/function_importer.h similarity index 94% rename from external/torch-mlir/TorchPlugin/csrc/builder/function_importer.h rename to external/torch-mlir/TorchPlugin/csrc/function_importer.h index 92986b44c..6be79de87 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/function_importer.h +++ b/external/torch-mlir/TorchPlugin/csrc/function_importer.h @@ -10,8 +10,8 @@ #include -#include "../pybind.h" #include "node_importer.h" +#include "pybind.h" #include "mlir-c/IR.h" @@ -40,9 +40,7 @@ namespace torch_mlir { MlirOperation importJitFunctionAsFuncOp( MlirContext context, torch::jit::Function *function, std::function getArgAttribute = - [](int) -> MlirAttribute { - return {nullptr}; - }); + [](int) -> MlirAttribute { return {nullptr}; }); } // namespace torch_mlir diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.cpp b/external/torch-mlir/TorchPlugin/csrc/ivalue_importer.cpp similarity index 99% rename from external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.cpp rename to external/torch-mlir/TorchPlugin/csrc/ivalue_importer.cpp index 1d2c25749..5dd218c29 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.cpp +++ b/external/torch-mlir/TorchPlugin/csrc/ivalue_importer.cpp @@ -19,8 +19,8 @@ #include "mlir-c/Diagnostics.h" #include "torch-mlir-c/TorchTypes.h" -#include "caffe2/core/scope_guard.h" #include "ATen/native/quantized/cpu/packed_params.h" +#include "caffe2/core/scope_guard.h" using namespace torch_mlir; @@ -149,8 +149,7 @@ private: }; } // namespace -MlirValue -IValueImporter::importModule(torch::jit::Module currentModule) { +MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { // TODO: Can we do better? MlirLocation loc = mlirLocationUnknownGet(context); diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.h b/external/torch-mlir/TorchPlugin/csrc/ivalue_importer.h similarity index 97% rename from external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.h rename to external/torch-mlir/TorchPlugin/csrc/ivalue_importer.h index b2dd4a15c..37459f337 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/ivalue_importer.h +++ b/external/torch-mlir/TorchPlugin/csrc/ivalue_importer.h @@ -10,8 +10,8 @@ #include -#include "../pybind.h" #include "class_annotator.h" +#include "pybind.h" #include "mlir-c/IR.h" diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/mlir_utils.h b/external/torch-mlir/TorchPlugin/csrc/mlir_utils.h similarity index 100% rename from external/torch-mlir/TorchPlugin/csrc/builder/mlir_utils.h rename to external/torch-mlir/TorchPlugin/csrc/mlir_utils.h diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/module_builder.cpp b/external/torch-mlir/TorchPlugin/csrc/module_builder.cpp similarity index 100% rename from external/torch-mlir/TorchPlugin/csrc/builder/module_builder.cpp rename to external/torch-mlir/TorchPlugin/csrc/module_builder.cpp diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/module_builder.h b/external/torch-mlir/TorchPlugin/csrc/module_builder.h similarity index 98% rename from external/torch-mlir/TorchPlugin/csrc/builder/module_builder.h rename to external/torch-mlir/TorchPlugin/csrc/module_builder.h index 1a1e78209..6a3856056 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/module_builder.h +++ b/external/torch-mlir/TorchPlugin/csrc/module_builder.h @@ -8,7 +8,7 @@ #ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_H #define TORCHMLIRPLUGIN_CSRC_BUILDER_H -#include "../pybind.h" +#include "pybind.h" #include "class_annotator.h" diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/node_importer.cpp b/external/torch-mlir/TorchPlugin/csrc/node_importer.cpp similarity index 97% rename from external/torch-mlir/TorchPlugin/csrc/builder/node_importer.cpp rename to external/torch-mlir/TorchPlugin/csrc/node_importer.cpp index d749ac50f..087e21a35 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/node_importer.cpp +++ b/external/torch-mlir/TorchPlugin/csrc/node_importer.cpp @@ -211,12 +211,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()), resultTypes, mlirRegionCreate(), mlirRegionCreate()); mapResults(node, operation); - auto createTerminator = - [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { - createMlirOperationAtEnd( - appendToBlock, "torch.prim.If.yield", loc, - derefineValues(yieldedValues, resultTypes, loc, appendToBlock)); - }; + auto createTerminator = [&](c10::ArrayRef yieldedValues, + MlirBlock appendToBlock) { + createMlirOperationAtEnd( + appendToBlock, "torch.prim.If.yield", loc, + derefineValues(yieldedValues, resultTypes, loc, appendToBlock)); + }; mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 0), importBlock(node->blocks()[0], createTerminator)); diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/node_importer.h b/external/torch-mlir/TorchPlugin/csrc/node_importer.h similarity index 97% rename from external/torch-mlir/TorchPlugin/csrc/builder/node_importer.h rename to external/torch-mlir/TorchPlugin/csrc/node_importer.h index 8119efaee..efdab3e9c 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/node_importer.h +++ b/external/torch-mlir/TorchPlugin/csrc/node_importer.h @@ -10,7 +10,7 @@ #include -#include "../pybind.h" +#include "pybind.h" #include "mlir-c/IR.h" diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/python_bindings.cpp b/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp similarity index 97% rename from external/torch-mlir/TorchPlugin/csrc/builder/python_bindings.cpp rename to external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp index 57f9a43bc..0a091ebb7 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/python_bindings.cpp +++ b/external/torch-mlir/TorchPlugin/csrc/python_bindings.cpp @@ -5,13 +5,12 @@ // //===----------------------------------------------------------------------===// -#include "../pybind.h" +// This is the top-level entry point for the MLIR <-> PyTorch bridge. #include -#include "../init_python_bindings.h" -#include "module_builder.h" #include "class_annotator.h" +#include "module_builder.h" using namespace torch_mlir; namespace py = pybind11; @@ -122,7 +121,7 @@ py::list GetRegisteredOps() { } // namespace -void torch_mlir::InitBuilderBindings(py::module &m) { +PYBIND11_MODULE(_torch_mlir, m) { m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring); ModuleBuilder::bind(m); initClassAnnotatorBindings(m); diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/torch_to_mlir_utils.cpp b/external/torch-mlir/TorchPlugin/csrc/torch_to_mlir_utils.cpp similarity index 100% rename from external/torch-mlir/TorchPlugin/csrc/builder/torch_to_mlir_utils.cpp rename to external/torch-mlir/TorchPlugin/csrc/torch_to_mlir_utils.cpp diff --git a/external/torch-mlir/TorchPlugin/csrc/builder/torch_to_mlir_utils.h b/external/torch-mlir/TorchPlugin/csrc/torch_to_mlir_utils.h similarity index 99% rename from external/torch-mlir/TorchPlugin/csrc/builder/torch_to_mlir_utils.h rename to external/torch-mlir/TorchPlugin/csrc/torch_to_mlir_utils.h index fcf659a56..ce8ccfb65 100644 --- a/external/torch-mlir/TorchPlugin/csrc/builder/torch_to_mlir_utils.h +++ b/external/torch-mlir/TorchPlugin/csrc/torch_to_mlir_utils.h @@ -10,7 +10,7 @@ #include -#include "../pybind.h" +#include "pybind.h" #include "mlir-c/IR.h"