diff --git a/projects/jit_ir_common/CMakeLists.txt b/projects/jit_ir_common/CMakeLists.txt index e69de29bb..f0a3ff596 100644 --- a/projects/jit_ir_common/CMakeLists.txt +++ b/projects/jit_ir_common/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc/jit_ir_importer) diff --git a/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt b/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt new file mode 100644 index 000000000..32c03c56f --- /dev/null +++ b/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt @@ -0,0 +1,26 @@ +# Static library with core functionality. +# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) +# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 +add_library(TorchMLIRJITIRImporter STATIC + class_annotator.cpp + function_importer.cpp + node_importer.cpp + ivalue_importer.cpp + torch_to_mlir_utils.cpp + ) +target_link_libraries(TorchMLIRJITIRImporter + TorchMLIRAggregateCAPI + ${TORCH_LIBRARIES} + ) +# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...") +target_include_directories(TorchMLIRJITIRImporter PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. +) +set_target_properties(TorchMLIRJITIRImporter PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + OUTPUT_NAME lib_jit_ir_importer + PREFIX "" + SUFFIX ".a" + CXX_VISIBILITY_PRESET "default" + COMPILE_FLAGS "${TORCH_CXXFLAGS}" + ) diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp similarity index 71% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp index 9f936486f..b144e946b 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp @@ -18,8 +18,8 @@ using namespace torch_mlir; //===----------------------------------------------------------------------===// // Prefix every line of `s` with `linePrefix`. -static std::string -indentString(const std::string& linePrefix, const std::string& s) { +static std::string indentString(const std::string &linePrefix, + const std::string &s) { std::stringstream is(s); std::stringstream os; std::string line; @@ -39,28 +39,26 @@ ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType) methodAnnotations.resize(classType->methods().size()); } -std::vector& ClassAnnotation::getAttributeAnnotations() { +std::vector &ClassAnnotation::getAttributeAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // // We can't easily guard against attributes being removed and // then other attributes being added, or types changed, etc. without // effectively mirroring the entire ClassType. - assert( - attributeAnnotations.size() == classType->getAttributes().size() && - "annotations out of sync. class has been mutated"); + assert(attributeAnnotations.size() == classType->getAttributes().size() && + "annotations out of sync. class has been mutated"); return attributeAnnotations; } -std::vector& ClassAnnotation::getMethodAnnotations() { +std::vector &ClassAnnotation::getMethodAnnotations() { // Halfhearted attempt to ensure consistency if the class type has // been mutated. // // We can't easily guard against methods being removed, added, or changed. - assert( - methodAnnotations.size() == classType->methods().size() && - "annotations out of sync. class has been mutated"); + assert(methodAnnotations.size() == classType->methods().size() && + "annotations out of sync. class has been mutated"); return methodAnnotations; } @@ -69,17 +67,17 @@ std::vector& ClassAnnotation::getMethodAnnotations() { // ClassAnnotator //===----------------------------------------------------------------------===// -static void -exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) { - ClassAnnotation& classAnnotation = +static void exportNoneRecurse(ClassAnnotator &classAnnotator, + c10::ClassType *classType) { + ClassAnnotation &classAnnotation = classAnnotator.getOrCreateClassAnnotation(classType); - for (auto& attributeAnnotation : classAnnotation.getAttributeAnnotations()) { + for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) { attributeAnnotation.isExported = false; } - for (auto& methodAnnotation : classAnnotation.getMethodAnnotations()) { + for (auto &methodAnnotation : classAnnotation.getMethodAnnotations()) { methodAnnotation.isExported = false; } - for (auto& classAttribute : classType->getAttributes()) { + for (auto &classAttribute : classType->getAttributes()) { if (auto childClassType = classAttribute.getType()->cast()) { exportNoneRecurse(classAnnotator, childClassType.get()); @@ -87,20 +85,20 @@ exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) { } } -void ClassAnnotator::exportNone(c10::ClassType& rootClassType) { +void ClassAnnotator::exportNone(c10::ClassType &rootClassType) { exportNoneRecurse(*this, &rootClassType); } -void ClassAnnotator::exportPath( - c10::ClassType& rootClassType, std::vector exportedPath) { +void ClassAnnotator::exportPath(c10::ClassType &rootClassType, + std::vector exportedPath) { if (exportedPath.size() == 0) { throw std::invalid_argument( "Empty exported path. Can only export a property of a class."); } - c10::ClassType* classType = getClassAtPath( - &rootClassType, c10::ArrayRef(exportedPath) - .slice(0, exportedPath.size() - 1) - .vec()); + c10::ClassType *classType = + getClassAtPath(&rootClassType, c10::ArrayRef(exportedPath) + .slice(0, exportedPath.size() - 1) + .vec()); if (!classType->findAttribute(exportedPath.back()) && !classType->findMethod(exportedPath.back())) { @@ -110,10 +108,10 @@ void ClassAnnotator::exportPath( << exportedPath.back() << "'"; throw std::invalid_argument(ss.str()); } - ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); - std::vector& attributeAnnotations = + ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); + std::vector &attributeAnnotations = classAnnotation.getAttributeAnnotations(); - const std::vector& classAttributes = + const std::vector &classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { if (classAttributes[i].getName() == exportedPath.back()) { @@ -121,9 +119,9 @@ void ClassAnnotator::exportPath( } } - std::vector& methodAnnotations = + std::vector &methodAnnotations = classAnnotation.getMethodAnnotations(); - const std::vector& methods = classType->methods(); + const std::vector &methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { if (methods[i]->name() == exportedPath.back()) { methodAnnotations[i].isExported = true; @@ -131,12 +129,12 @@ void ClassAnnotator::exportPath( } } -const ClassAnnotationMap& ClassAnnotator::getAnnotationMap() { +const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() { return classAnnotations; } -ClassAnnotation& -ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) { +ClassAnnotation & +ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { auto className = classType->name()->qualifiedName(); auto it = classAnnotations.find(className); if (it == classAnnotations.end()) { @@ -151,39 +149,39 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) { return *it->second; } -static void fillArgAnnotations( - MethodAnnotation& methodAnnotation, - std::vector argAnnotations, torch::jit::Function* function) { +static void fillArgAnnotations(MethodAnnotation &methodAnnotation, + std::vector argAnnotations, + torch::jit::Function *function) { if (argAnnotations.size() != function->num_inputs()) { throw std::invalid_argument("Arg annotations should have one entry per " "function parameter (including self)."); } if (!methodAnnotation.argAnnotations.has_value()) { - methodAnnotation.argAnnotations.emplace( - function->num_inputs(), ArgAnnotation{}); + methodAnnotation.argAnnotations.emplace(function->num_inputs(), + ArgAnnotation{}); } methodAnnotation.argAnnotations = argAnnotations; } -void ClassAnnotator::annotateArgs( - c10::ClassType& rootClassType, std::vector path, - std::vector argAnnotations) { +void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, + std::vector path, + std::vector argAnnotations) { if (path.size() == 0) { throw std::invalid_argument("Empty annotated path. Can only annotate " "shapes/dtypes of a method of a class."); } - c10::ClassType* classType = getClassAtPath( + c10::ClassType *classType = getClassAtPath( &rootClassType, c10::ArrayRef(path).slice(0, path.size() - 1).vec()); // Throw error if no method on the class of the specified name. - torch::jit::Function* function = &classType->getMethod(path.back()); + torch::jit::Function *function = &classType->getMethod(path.back()); - ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); - std::vector& methodAnnotations = + ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); + std::vector &methodAnnotations = classAnnotation.getMethodAnnotations(); - const std::vector& methods = classType->methods(); + const std::vector &methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { if (methods[i]->name() == path.back()) { fillArgAnnotations(methodAnnotations[i], argAnnotations, function); @@ -193,9 +191,9 @@ void ClassAnnotator::annotateArgs( return; } -c10::ClassType* ClassAnnotator::getClassAtPath( - c10::ClassType* rootClassType, std::vector path) { - c10::ClassType* classType = rootClassType; +c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType, + std::vector path) { + c10::ClassType *classType = rootClassType; // Reverse so that pop_back gives us the initial atoms first. std::reverse(path.begin(), path.end()); while (!path.empty()) { @@ -217,8 +215,8 @@ c10::ClassType* ClassAnnotator::getClassAtPath( //===----------------------------------------------------------------------===// // Helper methods //===----------------------------------------------------------------------===// -MethodAnnotation* -ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) { +MethodAnnotation * +ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) { auto it = functionToMethodMap.find(function); if (it == functionToMethodMap.end()) { return nullptr; @@ -230,7 +228,7 @@ ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) { // toString methods //===----------------------------------------------------------------------===// -std::string AttributeAnnotation::toString(const std::string& name) { +std::string AttributeAnnotation::toString(const std::string &name) { std::stringstream ss; ss << "AttributeAnnotation('" << name << "') {\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n"; @@ -261,7 +259,7 @@ std::string ArgAnnotation::toString(int argIndex) { return ss.str(); } -std::string MethodAnnotation::toString(const std::string& name) { +std::string MethodAnnotation::toString(const std::string &name) { std::stringstream ss; ss << "MethodAnnotation('" << name << "') {\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n"; @@ -282,13 +280,13 @@ std::string ClassAnnotation::toString() { std::stringstream ss; ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n"; - const std::vector& classAttributes = + const std::vector &classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { ss << indentString( " ", attributeAnnotations[i].toString(classAttributes[i].getName())); } - const std::vector& methods = classType->methods(); + const std::vector &methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name())); } @@ -299,7 +297,7 @@ std::string ClassAnnotation::toString() { std::string ClassAnnotator::toString() { std::stringstream ss; ss << "ClassAnnotator {\n"; - for (auto& p : classAnnotations) { + for (auto &p : classAnnotations) { ss << indentString(" ", p.second->toString()); } ss << "}\n"; diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h similarity index 88% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h rename to projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h index 11aa4e434..0a0815eab 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h @@ -34,7 +34,7 @@ struct AttributeAnnotation { // can be externally accessed. bool isExported = true; - std::string toString(const std::string& name); + std::string toString(const std::string &name); }; // An annotation of an argument of a method. @@ -80,7 +80,7 @@ struct MethodAnnotation { // large printout of the default ArgAnnotation for every method. c10::optional> argAnnotations; - std::string toString(const std::string& name); + std::string toString(const std::string &name); }; // Annotations on a c10::ClassType. @@ -107,10 +107,10 @@ public: // Get the attribute annotations. // The length and order is the same as `classType->getAttributes()`. - std::vector& getAttributeAnnotations(); + std::vector &getAttributeAnnotations(); // Get the method annotations. // The length and order is the same as `classType->methods()`. - std::vector& getMethodAnnotations(); + std::vector &getMethodAnnotations(); std::string toString(); @@ -141,14 +141,14 @@ public: // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should // have a submodule `a` and that submodule should have a method or attribute // `b`. - void exportPath( - c10::ClassType& rootClassType, std::vector exportedPath); + void exportPath(c10::ClassType &rootClassType, + std::vector exportedPath); // Mark everything as not-exported. // // This is kind of useless by itself, but together with `exportPath` allows // exporting a subset of known names out of a larger collection of unknown // names. - void exportNone(c10::ClassType& rootClassType); + void exportNone(c10::ClassType &rootClassType); // Annotate shapes and dtypes of the arguments of a method at path `path` from // `rootClassType`. @@ -159,23 +159,23 @@ public: // a "has value semantics" boolean. // These will be put into an `ArgAnnotation` struct -- see there for // precise definitions of the promised semantics of each entry. - void annotateArgs( - c10::ClassType& rootClassType, std::vector path, - std::vector argAnnotations); + void annotateArgs(c10::ClassType &rootClassType, + std::vector path, + std::vector argAnnotations); // The annotations collected so far. - const ClassAnnotationMap& getAnnotationMap(); + const ClassAnnotationMap &getAnnotationMap(); // Get the ClassAnnotation corresponding to `classType`. - ClassAnnotation& getOrCreateClassAnnotation(c10::ClassType* classType); + ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType); // Helper to find the MethodAnnotation corresponding to a // torch::jit::Function, or null if not found. // // Users could in principle scan all annotations to find this, but it's more // efficient to maintain the reverse mapping directly. - MethodAnnotation* - getMethodAnnotationForFunction(torch::jit::Function* function); + MethodAnnotation * + getMethodAnnotationForFunction(torch::jit::Function *function); std::string toString(); @@ -183,11 +183,11 @@ private: // Traverse `path` starting from `rootClassType` to find the ClassType // of a presumed nested submodule. Throw an error if there is no such // submodule. - c10::ClassType* - getClassAtPath(c10::ClassType* rootClassType, std::vector path); + c10::ClassType *getClassAtPath(c10::ClassType *rootClassType, + std::vector path); ClassAnnotationMap classAnnotations; // Reverse mapping used to service getMethodAnnotationForFunction. - std::unordered_map + std::unordered_map functionToMethodMap; }; diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.cpp similarity index 88% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/function_importer.cpp index 31d560a73..4a538fbcb 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.cpp @@ -21,9 +21,9 @@ using namespace torch_mlir; MlirOperation torch_mlir::importJitFunctionAsFuncOp( - MlirContext context, torch::jit::Function* function, + MlirContext context, torch::jit::Function *function, std::function getArgAttribute, - const ImportOptions& importOptions) { + const ImportOptions &importOptions) { // Useful for debugging: // graph->dump(); MlirLocation loc = mlirLocationUnknownGet(context); @@ -63,11 +63,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( } auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { - createMlirOperationAtEnd( - appendToBlock, "func.return", loc, - adjustStaticInformationForValues( - appendToBlock, loc, yieldedValues, resultTypes, - /*userAllowsRefinement=*/false)); + createMlirOperationAtEnd(appendToBlock, "func.return", loc, + adjustStaticInformationForValues( + appendToBlock, loc, yieldedValues, resultTypes, + /*userAllowsRefinement=*/false)); }; MlirBlock block = importBlock( context, torch::jit::toGraphFunction(*function).graph()->block(), diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.h similarity index 94% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/function_importer.h index a211f6c46..626068f76 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/function_importer.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.h @@ -40,10 +40,10 @@ namespace torch_mlir { /// null MlirAttribute is returned, no attribute will be attached to that /// argument. MlirOperation importJitFunctionAsFuncOp( - MlirContext context, torch::jit::Function* function, + MlirContext context, torch::jit::Function *function, std::function getArgAttribute = [](int) -> MlirAttribute { return {nullptr}; }, - const ImportOptions& importOptions = {}); + const ImportOptions &importOptions = {}); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options.h b/projects/jit_ir_common/csrc/jit_ir_importer/import_options.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options.h rename to projects/jit_ir_common/csrc/jit_ir_importer/import_options.h diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp similarity index 86% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp index 73321817e..ef02096eb 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp @@ -49,10 +49,10 @@ using namespace torch_mlir; // throw an error on). namespace { struct IValueHasher { - size_t operator()(const c10::IValue& ivalue) const { + size_t operator()(const c10::IValue &ivalue) const { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { - return std::hash()( - static_cast(ivalue.internalToPointer())); + return std::hash()( + static_cast(ivalue.internalToPointer())); } return c10::IValue::hash(ivalue); @@ -65,7 +65,7 @@ struct IValueHasher { // such as when tracing). Can we do better? namespace { struct IValueEq { - bool operator()(const c10::IValue& lhs, const c10::IValue& rhs) const { + bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const { return lhs.isSameIdentity(rhs); } }; @@ -99,9 +99,8 @@ namespace { /// (PyTorch allows this!). class IValueImporter { public: - IValueImporter( - MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator, - const ImportOptions& importOptions) + IValueImporter(MlirBlock importBlock, MlirContext context, + ClassAnnotator &annotator, const ImportOptions &importOptions) : importBlock(importBlock), context(context), annotator(annotator), importOptions(importOptions) {} @@ -111,16 +110,15 @@ private: MlirValue rawImportIValue(c10::IValue ivalue); MlirValue importTensor(c10::IValue ivalue); MlirValue importModule(torch::jit::Module jitModule); - void importMethod( - torch::jit::Function* function, MlirBlock classTypeBody, - const MethodAnnotation& methodAnnotation); - void importClassType(c10::ClassType* classType); - void importCompilationUnit(torch::jit::CompilationUnit* cu); + void importMethod(torch::jit::Function *function, MlirBlock classTypeBody, + const MethodAnnotation &methodAnnotation); + void importClassType(c10::ClassType *classType); + void importCompilationUnit(torch::jit::CompilationUnit *cu); MlirBlock importBlock; MlirContext context; - ClassAnnotator& annotator; - const ImportOptions& importOptions; + ClassAnnotator &annotator; + const ImportOptions &importOptions; // Map tracking already-imported values. std::unordered_map valueMap; @@ -131,16 +129,16 @@ private: // e.g. methods (the function names are meaningful and match with Python's // module hierarchy, with the exception of `__main__` being replaced with // `__torch__`). - torch::jit::CompilationUnit* compilationUnit = nullptr; + torch::jit::CompilationUnit *compilationUnit = nullptr; // Used to detect potentially aliasing tensors. - std::unordered_set seenStorageImpls; + std::unordered_set seenStorageImpls; // The set of ClassType's that have already been imported. // // ClassType's are referenced via their `classType->name()->qualifiedName()` // string (as an MLIR symbol name) so we don't need to keep a map associating // them with the MlirOperation that they import into. - std::unordered_set classTypes; + std::unordered_set classTypes; // The stack of attribute names we have traversed to reach the current IValue. // Used for diagnostics. std::vector attributeNameStack; @@ -192,8 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), mlirRegionCreate()); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); - mlirRegionAppendOwnedBlock( - nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); + mlirRegionAppendOwnedBlock(nnModuleRegion, + mlirBlockCreate(0, nullptr, nullptr)); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); InserterGuard inserterGuard(importBlock, nnModule); @@ -201,14 +199,13 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { rootModuleName = moduleTypeName; } - const std::vector& slots = currentModule._ivalue()->slots(); - const std::vector& classAttributes = + const std::vector &slots = currentModule._ivalue()->slots(); + const std::vector &classAttributes = currentModule.type()->getAttributes(); - assert( - slots.size() == classAttributes.size() && - "mismatch between object and type!"); + assert(slots.size() == classAttributes.size() && + "mismatch between object and type!"); for (int i = 0, e = slots.size(); i < e; i++) { - const c10::ClassAttribute& classAttribute = classAttributes[i]; + const c10::ClassAttribute &classAttribute = classAttributes[i]; attributeNameStack.push_back(classAttribute.getName()); MlirValue slotValue = importIValue(slots[i]); // TODO: Is it necessary to track whether an attribute is a "parameter"? @@ -235,7 +232,7 @@ MlirValue IValueImporter::importIValue(c10::IValue ivalue) { } // Reject potentially aliased tensors. if (ivalue.isTensor()) { - c10::StorageImpl* storageImpl = + c10::StorageImpl *storageImpl = ivalue.toTensor().storage().unsafeGetStorageImpl(); if (!seenStorageImpls.insert(storageImpl).second) { std::stringstream msg; @@ -261,8 +258,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirType type = torchMlirTorchBoolTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.bool", loc, type, - toMlirNamedAttribute( - "value", mlirBoolAttrGet(context, ivalue.toBool()))); + toMlirNamedAttribute("value", + mlirBoolAttrGet(context, ivalue.toBool()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isDouble()) { @@ -270,23 +267,23 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.float", loc, type, toMlirNamedAttribute( - "value", mlirFloatAttrDoubleGet( - context, mlirF64TypeGet(context), ivalue.toDouble()))); + "value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), + ivalue.toDouble()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isInt()) { MlirType type = torchMlirTorchIntTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.constant.int", loc, type, - toMlirNamedAttribute( - "value", mlirIntegerAttrGet( - mlirIntegerTypeGet(context, 64), ivalue.toInt()))); + toMlirNamedAttribute("value", + mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), + ivalue.toInt()))); return mlirOperationGetResult(operation, 0); } if (ivalue.isList()) { c10::List list = ivalue.toList(); std::vector elems; - for (const c10::IValue& elem : list) { + for (const c10::IValue &elem : list) { elems.push_back(importIValue(elem)); } MlirOperation operation = createMlirOperationAtEnd( @@ -316,7 +313,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { auto list = ivalue.toTuple()->elements(); std::vector operands; std::vector types; - for (const c10::IValue& elem : list) { + for (const c10::IValue &elem : list) { MlirValue operand = importIValue(elem); operands.push_back(operand); types.push_back(mlirValueGetType(operand)); @@ -339,14 +336,14 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { torchMlirTorchStringTypeGet(context), toMlirNamedAttribute( "value", - mlirStringAttrGet( - context, toMlirStringRef(ivalue.toString()->string())))); + mlirStringAttrGet(context, + toMlirStringRef(ivalue.toString()->string())))); return mlirOperationGetResult(operation, 0); } if (ivalue.isNone()) { - MlirOperation operation = createMlirOperationAtEnd( - importBlock, "torch.constant.none", loc, - torchMlirTorchNoneTypeGet(context)); + MlirOperation operation = + createMlirOperationAtEnd(importBlock, "torch.constant.none", loc, + torchMlirTorchNoneTypeGet(context)); return mlirOperationGetResult(operation, 0); } if (ivalue.isCustomClass()) { @@ -440,12 +437,12 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { return tensorValue; } -void IValueImporter::importMethod( - torch::jit::Function* function, MlirBlock classTypeBody, - const MethodAnnotation& methodAnnotation) { +void IValueImporter::importMethod(torch::jit::Function *function, + MlirBlock classTypeBody, + const MethodAnnotation &methodAnnotation) { // The function's name becomes the MLIR symbol table name of the imported func // when we import the compilation unit. - const std::string& symName = function->qualname().qualifiedName(); + const std::string &symName = function->qualname().qualifiedName(); MlirAttribute functionSymbolRef = mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)); @@ -461,7 +458,7 @@ void IValueImporter::importMethod( toMlirNamedAttribute("function", functionSymbolRef), isPrivate); } -void IValueImporter::importClassType(c10::ClassType* classType) { +void IValueImporter::importClassType(c10::ClassType *classType) { if (!classTypes.insert(classType).second) { return; } @@ -479,13 +476,13 @@ void IValueImporter::importClassType(c10::ClassType* classType) { mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr)); MlirBlock classTypeBody = mlirRegionGetFirstBlock(region); - ClassAnnotation& classAnnotation = + ClassAnnotation &classAnnotation = annotator.getOrCreateClassAnnotation(classType); - const auto& attributeAnnotations = classAnnotation.getAttributeAnnotations(); - const auto& classAttributes = classType->getAttributes(); + const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations(); + const auto &classAttributes = classType->getAttributes(); for (int i = 0, e = classAttributes.size(); i != e; i++) { - const c10::ClassAttribute& classAttribute = classAttributes[i]; + const c10::ClassAttribute &classAttribute = classAttributes[i]; c10::optional isPrivate; if (!attributeAnnotations[i].isExported) { isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context)); @@ -501,8 +498,8 @@ void IValueImporter::importClassType(c10::ClassType* classType) { isPrivate); } - const auto& methodAnnotations = classAnnotation.getMethodAnnotations(); - const auto& methods = classType->methods(); + const auto &methodAnnotations = classAnnotation.getMethodAnnotations(); + const auto &methods = classType->methods(); for (int i = 0, e = methods.size(); i != e; i++) { importMethod(methods[i], classTypeBody, methodAnnotations[i]); } @@ -510,7 +507,7 @@ void IValueImporter::importClassType(c10::ClassType* classType) { createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc); } -void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { +void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { if (compilationUnit == nullptr) { compilationUnit = cu; } else { @@ -529,14 +526,14 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { return; } - for (torch::jit::Function* function : cu->get_functions()) { + for (torch::jit::Function *function : cu->get_functions()) { // Useful for debugging errors in free functions that end up being // unused. These can be missing when round-tripping through the on-disk // format, even though they still cause import issues when importing // through the larger Python session where they originate. // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n"; // std::cerr << *torch::jit::toGraphFunction(function).graph(); - MethodAnnotation* annotation = + MethodAnnotation *annotation = annotator.getMethodAnnotationForFunction(function); MlirOperation func = importJitFunctionAsFuncOp( context, function, @@ -544,9 +541,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { if (!annotation || !annotation->argAnnotations.has_value()) { return {nullptr}; } - c10::optional>& maybeShape = + c10::optional> &maybeShape = annotation->argAnnotations.value()[argIndex].shape; - c10::optional& maybeDtype = + c10::optional &maybeDtype = annotation->argAnnotations.value()[argIndex].dtype; bool hasValueSemantics = annotation->argAnnotations.value()[argIndex].hasValueSemantics; @@ -566,10 +563,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { // the C API constructor, when we want the "we know we have 0 sizes" // case. So use a dummy data pointer. int64_t dummy; - int64_t* shapeData = shape.size() == 0 ? &dummy : shape.data(); + int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data(); if (hasValueSemantics) { - typeBound = torchMlirTorchValueTensorTypeGet( - context, shape.size(), shapeData, dtype); + typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(), + shapeData, dtype); } else { typeBound = torchMlirTorchNonValueTensorTypeGet( context, shape.size(), shapeData, dtype); @@ -597,9 +594,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { } } -MlirValue torch_mlir::importIValue( - c10::IValue ivalue, MlirBlock block, MlirContext context, - ClassAnnotator& annotator, const ImportOptions& importOptions) { +MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block, + MlirContext context, + ClassAnnotator &annotator, + const ImportOptions &importOptions) { // When debugging module importing, it can be useful to dump as so: // if (ivalue.isModule()) // ivalue.toModule().dump(true, false, false); diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h similarity index 83% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h index ae3deb945..7cbc7ece8 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/ivalue_importer.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h @@ -25,9 +25,9 @@ namespace torch_mlir { /// Main entry-point for importing torch IValue's . /// Recursively imports `ivalue`, inserting operations at the end of `block`. -MlirValue importIValue( - c10::IValue ivalue, MlirBlock block, MlirContext context, - ClassAnnotator& annotator, const ImportOptions& importOptions); +MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context, + ClassAnnotator &annotator, + const ImportOptions &importOptions); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h b/projects/jit_ir_common/csrc/jit_ir_importer/mlir_utils.h similarity index 54% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h rename to projects/jit_ir_common/csrc/jit_ir_importer/mlir_utils.h index 1e033f0d8..97ce5fa10 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/mlir_utils.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/mlir_utils.h @@ -22,92 +22,92 @@ namespace torch_mlir { -inline MlirStringRef toMlirStringRef(const std::string& s) { +inline MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -inline MlirStringRef toMlirStringRef(const char* s) { +inline MlirStringRef toMlirStringRef(const char *s) { return mlirStringRefCreate(s, std::strlen(s)); } -inline MlirNamedAttribute -toMlirNamedAttribute(const char* s, MlirAttribute attr) { +inline MlirNamedAttribute toMlirNamedAttribute(const char *s, + MlirAttribute attr) { MlirContext context = mlirAttributeGetContext(attr); MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); return mlirNamedAttributeGet(ident, attr); } -inline void addToMlirOperationState( - MlirOperationState& state, MlirNamedAttribute namedAttr) { +inline void addToMlirOperationState(MlirOperationState &state, + MlirNamedAttribute namedAttr) { mlirOperationStateAddAttributes(&state, 1, &namedAttr); } -inline void -addToMlirOperationState(MlirOperationState& state, MlirRegion region) { +inline void addToMlirOperationState(MlirOperationState &state, + MlirRegion region) { mlirOperationStateAddOwnedRegions(&state, 1, ®ion); } -inline void -addToMlirOperationState(MlirOperationState& state, MlirValue value) { +inline void addToMlirOperationState(MlirOperationState &state, + MlirValue value) { mlirOperationStateAddOperands(&state, 1, &value); } -inline void addToMlirOperationState( - MlirOperationState& state, const std::vector& values) { +inline void addToMlirOperationState(MlirOperationState &state, + const std::vector &values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } -inline void addToMlirOperationState( - MlirOperationState& state, c10::ArrayRef values) { +inline void addToMlirOperationState(MlirOperationState &state, + c10::ArrayRef values) { mlirOperationStateAddOperands(&state, values.size(), values.data()); } -inline void -addToMlirOperationState(MlirOperationState& state, MlirType resultType) { +inline void addToMlirOperationState(MlirOperationState &state, + MlirType resultType) { mlirOperationStateAddResults(&state, 1, &resultType); } -inline void addToMlirOperationState( - MlirOperationState& state, const std::vector& resultTypes) { +inline void addToMlirOperationState(MlirOperationState &state, + const std::vector &resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } -inline void addToMlirOperationState( - MlirOperationState& state, c10::ArrayRef resultTypes) { +inline void addToMlirOperationState(MlirOperationState &state, + c10::ArrayRef resultTypes) { mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); } template -void addToMlirOperationState(MlirOperationState& state, c10::optional o) { +void addToMlirOperationState(MlirOperationState &state, c10::optional o) { if (o.has_value()) { addToMlirOperationState(state, o.value()); } } -inline void addToMlirOperationState(MlirOperationState& state) {} +inline void addToMlirOperationState(MlirOperationState &state) {} template -void addToMlirOperationState( - MlirOperationState& state, T&& t, U&& u, Ts&&... ts) { +void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, + Ts &&...ts) { addToMlirOperationState(state, std::forward(t)); addToMlirOperationState(state, std::forward(u), std::forward(ts)...); } template -MlirOperation -createMlirOperation(std::string name, MlirLocation loc, Ts&&... ts) { +MlirOperation createMlirOperation(std::string name, MlirLocation loc, + Ts &&...ts) { MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); addToMlirOperationState(state, std::forward(ts)...); return mlirOperationCreate(&state); } template -MlirOperation createMlirOperationAtEnd( - MlirBlock block, std::string name, MlirLocation loc, Ts&&... ts) { +MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name, + MlirLocation loc, Ts &&...ts) { MlirOperation operation = createMlirOperation(name, loc, std::forward(ts)...); - mlirBlockInsertOwnedOperationBefore( - block, mlirBlockGetTerminator(block), operation); + mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block), + operation); return operation; } diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp similarity index 65% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp index e9be84acc..0bb4722fc 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp @@ -33,42 +33,40 @@ class NodeImporter { public: NodeImporter(MlirContext context) : context(context) {} - void importNode( - Node* node, MlirBlock appendToBlock, - const ImportOptions& importOptions = {}); + void importNode(Node *node, MlirBlock appendToBlock, + const ImportOptions &importOptions = {}); MlirBlock importBlock( - Block* jitBlock, CreateTerminatorFn createTerminator, + Block *jitBlock, CreateTerminatorFn createTerminator, c10::optional> blockArgTypes = c10::nullopt, - const ImportOptions& importOptions = {}); + const ImportOptions &importOptions = {}); private: - MlirBlock createBlockFor( - Block* jitBlock, c10::optional> blockArgTypes, - const ImportOptions& importOptions = {}); - void mapValue(Value* jitValue, MlirValue value); - void mapResults(Node* node, MlirOperation operation); - MlirValue lookupMappedValue(Value* jitValue); - std::vector lookupMappedValues(c10::ArrayRef values); + MlirBlock createBlockFor(Block *jitBlock, + c10::optional> blockArgTypes, + const ImportOptions &importOptions = {}); + void mapValue(Value *jitValue, MlirValue value); + void mapResults(Node *node, MlirOperation operation); + MlirValue lookupMappedValue(Value *jitValue); + std::vector lookupMappedValues(c10::ArrayRef values); MlirContext context; - std::unordered_map valueMap; + std::unordered_map valueMap; }; } // namespace using InputsTransformFn = - std::function(std::vector&)>; + std::function(std::vector &)>; // The inputs of `DictConstruct` in TorchScript IR are in the order // like k0, v0, k1, v1. Rearrange them to put the key operands together and // then the value operands like k0, k1,v0, v1. This is the expected format by // the corresponding MLIR op. static std::vector -rearrangeDictConstructInputs(std::vector& inputs) { +rearrangeDictConstructInputs(std::vector &inputs) { if (inputs.empty()) return inputs; - assert( - inputs.size() % 2 == 0 && - "DictConstruct must have even number of operands"); + assert(inputs.size() % 2 == 0 && + "DictConstruct must have even number of operands"); std::vector rearranged; std::vector values; @@ -80,12 +78,12 @@ rearrangeDictConstructInputs(std::vector& inputs) { return rearranged; } -void NodeImporter::importNode( - Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) { +void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, + const ImportOptions &importOptions) { MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); - auto createAndMapTrivialNode = [&](Node* node, const std::string& opName, + auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, InputsTransformFn t) { std::vector mappedInputs = lookupMappedValues(node->inputs()); MlirOperation operation = createMlirOperationAtEnd( @@ -96,7 +94,7 @@ void NodeImporter::importNode( }; auto createAndMapNodeWithAttribute = - [&](Node* node, const std::string& opName, const std::string& attrName, + [&](Node *node, const std::string &opName, const std::string &attrName, MlirAttribute attr) { MlirOperation operation = createMlirOperationAtEnd( appendToBlock, opName, loc, @@ -133,27 +131,27 @@ void NodeImporter::importNode( // ListConstruct and DictConstruct too. auto containedTypes = c10::fmap( node->output()->type()->cast()->containedTypes(), - [&](const c10::TypePtr& t) { + [&](const c10::TypePtr &t) { MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); if (mlirTypeIsNull(type)) { throw mlir_diagnostic_emitted(); } return type; }); - createAndMapTrivialNode( - node, "torch.prim." + std::string(kind.toUnqualString()), - [&](std::vector& inputs) { - assert(containedTypes.size() == inputs.size()); - return adjustStaticInformationForValues( - appendToBlock, loc, inputs, containedTypes, - /*userAllowsRefinement=*/true); - }); + createAndMapTrivialNode(node, + "torch.prim." + std::string(kind.toUnqualString()), + [&](std::vector &inputs) { + assert(containedTypes.size() == inputs.size()); + return adjustStaticInformationForValues( + appendToBlock, loc, inputs, containedTypes, + /*userAllowsRefinement=*/true); + }); return; } case c10::prim::DictConstruct: { - createAndMapTrivialNode( - node, "torch.prim." + std::string(kind.toUnqualString()), - rearrangeDictConstructInputs); + createAndMapTrivialNode(node, + "torch.prim." + std::string(kind.toUnqualString()), + rearrangeDictConstructInputs); return; } case c10::prim::Load: @@ -171,34 +169,32 @@ void NodeImporter::importNode( auto output = node->output(); MlirOperation op; if (output->type()->cast()) { - op = createMlirOperation( - "torch.constant.none", loc, torchMlirTorchNoneTypeGet(context)); + op = createMlirOperation("torch.constant.none", loc, + torchMlirTorchNoneTypeGet(context)); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context), toMlirNamedAttribute( - "value", - mlirBoolAttrGet( - context, static_cast(node->i(c10::attr::value))))); + "value", mlirBoolAttrGet(context, static_cast(node->i( + c10::attr::value))))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.int", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), - toMlirNamedAttribute( - "value", importAttribute(loc, node, c10::attr::value))); + toMlirNamedAttribute("value", + importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.float", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), - toMlirNamedAttribute( - "value", importAttribute(loc, node, c10::attr::value))); + toMlirNamedAttribute("value", + importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.str", loc, torchMlirTorchStringTypeGet(context), toMlirNamedAttribute( - "value", - mlirStringAttrGet( - context, toMlirStringRef(node->s(c10::attr::value))))); + "value", mlirStringAttrGet(context, toMlirStringRef(node->s( + c10::attr::value))))); } else if (output->type()->cast()) { MlirAttribute attr = importAttribute(loc, node, c10::attr::value); if (importOptions.assumeTensorsHaveValueSemantics) { @@ -217,26 +213,24 @@ void NodeImporter::importNode( "torch.constant.device", loc, getMlirTypeFromTorchType(loc, output->type(), importOptions), toMlirNamedAttribute( - "value", - mlirStringAttrGet( - context, toMlirStringRef(node->s(c10::attr::value))))); + "value", mlirStringAttrGet(context, toMlirStringRef(node->s( + c10::attr::value))))); } else if (auto functionType = output->type()->cast()) { - torch::jit::Function* function = functionType->function(); - const std::string& symName = function->qualname().qualifiedName(); + torch::jit::Function *function = functionType->function(); + const std::string &symName = function->qualname().qualifiedName(); op = createMlirOperation( "func.constant", loc, - getFunctionTypeFromSchema( - context, function->getSchema(), importOptions), + getFunctionTypeFromSchema(context, function->getSchema(), + importOptions), toMlirNamedAttribute( "value", mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); - } else if ( - output->type()->cast() || - output->type()->cast()) { + } else if (output->type()->cast() || + output->type()->cast()) { ClassAnnotator dummyAnnotator; - MlirValue listOrTupleValue = importIValue( - node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator, - importOptions); + MlirValue listOrTupleValue = + importIValue(node->ival(c10::attr::value), appendToBlock, context, + dummyAnnotator, importOptions); mapResults(node, mlirOpResultGetOwner(listOrTupleValue)); return; // Early return, since `importIValue` already added op to block. } else { @@ -264,20 +258,19 @@ void NodeImporter::importNode( mapResults(node, operation); std::vector terminatorOperandTypes = { torchMlirTorchBoolTypeGet(context)}; - terminatorOperandTypes.insert( - terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end()); + terminatorOperandTypes.insert(terminatorOperandTypes.end(), + resultTypes.begin(), resultTypes.end()); auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { createMlirOperationAtEnd( appendToBlock, "torch.prim.Loop.condition", loc, - adjustStaticInformationForValues( - appendToBlock, loc, yieldedValues, terminatorOperandTypes, - /*userAllowsRefinement=*/false)); + adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, + terminatorOperandTypes, + /*userAllowsRefinement=*/false)); }; - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 0), - importBlock( - node->blocks()[0], createTerminator, c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0), + importBlock(node->blocks()[0], createTerminator, + c10::nullopt, importOptions)); return; } @@ -292,25 +285,23 @@ void NodeImporter::importNode( MlirBlock appendToBlock) { createMlirOperationAtEnd( appendToBlock, "torch.prim.If.yield", loc, - adjustStaticInformationForValues( - appendToBlock, loc, yieldedValues, resultTypes, - /*userAllowsRefinement=*/false)); + adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, + resultTypes, + /*userAllowsRefinement=*/false)); }; - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 0), - importBlock( - node->blocks()[0], createTerminator, c10::nullopt, importOptions)); - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 1), - importBlock( - node->blocks()[1], createTerminator, c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0), + importBlock(node->blocks()[0], createTerminator, + c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1), + importBlock(node->blocks()[1], createTerminator, + c10::nullopt, importOptions)); return; } if (kind == c10::prim::CallMethod) { auto classType = node->input(0)->type()->cast(); auto methodName = node->s(c10::attr::name); - torch::jit::Function* function = classType->findMethod(methodName); + torch::jit::Function *function = classType->findMethod(methodName); MlirType calleeType = getFunctionTypeFromSchema( context, function->getSchema(), importOptions); std::vector expectedTypes; @@ -323,17 +314,17 @@ void NodeImporter::importNode( adjustStaticInformationForValues( appendToBlock, loc, lookupMappedValues(node->inputs()), expectedTypes, /*userAllowsRefinement=*/false), - toMlirNamedAttribute( - "name", importAttribute(loc, node, c10::attr::name))); + toMlirNamedAttribute("name", + importAttribute(loc, node, c10::attr::name))); mapResults(node, operation); return; } if (kind == c10::prim::CallFunction) { auto functionType = node->input(0)->type()->cast(); - torch::jit::Block* calleeEntryBlock = + torch::jit::Block *calleeEntryBlock = torch::jit::toGraphFunction(*functionType->function()).graph()->block(); - auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value* v) { + auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { return getMlirTypeFromTorchType(loc, v->type(), importOptions); }); std::string functionName = node->input(0)->node()->s(c10::attr::name); @@ -348,9 +339,9 @@ void NodeImporter::importNode( // promoted result dtype for a PyTorch computation. Here we turn the call to // this function to the torch dialect equivalent op `torch.promote_dtypes`. if (functionName == "__torch_mlir_internal_promote_dtypes") { - operation = createMlirOperationAtEnd( - appendToBlock, "torch.promote_dtypes", loc, resultTypes, - adjustedFuncArgs); + operation = + createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc, + resultTypes, adjustedFuncArgs); } else { operation = createMlirOperationAtEnd( appendToBlock, "func.call_indirect", loc, resultTypes, @@ -369,23 +360,23 @@ void NodeImporter::importNode( } } -MlirBlock NodeImporter::importBlock( - Block* jitBlock, CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes, - const ImportOptions& importOptions) { +MlirBlock +NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes, + const ImportOptions &importOptions) { MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions); - for (Node* node : jitBlock->nodes()) { + for (Node *node : jitBlock->nodes()) { importNode(node, block, importOptions); } - Node* returnNode = jitBlock->return_node(); + Node *returnNode = jitBlock->return_node(); createTerminator(lookupMappedValues(returnNode->inputs()), block); return block; } MlirBlock NodeImporter::createBlockFor( - Block* jitBlock, c10::optional> blockArgTypes, - const ImportOptions& importOptions) { - Node* paramNode = jitBlock->param_node(); + Block *jitBlock, c10::optional> blockArgTypes, + const ImportOptions &importOptions) { + Node *paramNode = jitBlock->param_node(); MlirLocation loc = getMlirLocationFromNode(context, paramNode); std::vector paramNodeTypes = getMlirTypesFromValues(loc, paramNode->outputs(), importOptions); @@ -394,11 +385,11 @@ MlirBlock NodeImporter::createBlockFor( else assert(blockArgTypes->size() == paramNodeTypes.size()); std::vector blockArgLocs(paramNodeTypes.size(), loc); - MlirBlock block = mlirBlockCreate( - blockArgTypes.value().size(), blockArgTypes.value().data(), - blockArgLocs.data()); + MlirBlock block = + mlirBlockCreate(blockArgTypes.value().size(), + blockArgTypes.value().data(), blockArgLocs.data()); for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) { - Value* jitValue = paramNode->outputs()[i]; + Value *jitValue = paramNode->outputs()[i]; MlirValue value = mlirBlockGetArgument(block, i); MlirValue adjusted = adjustStaticInformationForValues( block, loc, {value}, {paramNodeTypes[i]}, @@ -408,40 +399,40 @@ MlirBlock NodeImporter::createBlockFor( return block; } -void NodeImporter::mapValue(Value* jitValue, MlirValue value) { +void NodeImporter::mapValue(Value *jitValue, MlirValue value) { auto it = valueMap.find(jitValue); (void)it; assert(it == valueMap.end() && "jitValue has already been mapped"); valueMap[jitValue] = value; } -void NodeImporter::mapResults(Node* node, MlirOperation operation) { - assert( - node->outputs().size() == (size_t)mlirOperationGetNumResults(operation)); +void NodeImporter::mapResults(Node *node, MlirOperation operation) { + assert(node->outputs().size() == + (size_t)mlirOperationGetNumResults(operation)); for (int i = 0, e = node->outputs().size(); i < e; i++) { mapValue(node->outputs()[i], mlirOperationGetResult(operation, i)); } } -MlirValue NodeImporter::lookupMappedValue(Value* jitValue) { +MlirValue NodeImporter::lookupMappedValue(Value *jitValue) { auto it = valueMap.find(jitValue); - assert( - it != valueMap.end() && - "trying to get mapping for jitValue that is not mapped yet!"); + assert(it != valueMap.end() && + "trying to get mapping for jitValue that is not mapped yet!"); return it->second; } std::vector -NodeImporter::lookupMappedValues(c10::ArrayRef values) { +NodeImporter::lookupMappedValues(c10::ArrayRef values) { std::vector ret; - for (Value* value : values) { + for (Value *value : values) { ret.push_back(lookupMappedValue(value)); } return ret; } -MlirBlock torch_mlir::importBlock( - MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes, - const ImportOptions& importOptions) { +MlirBlock +torch_mlir::importBlock(MlirContext context, Block *jitBlock, + CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes, + const ImportOptions &importOptions) { NodeImporter importer(context); - return importer.importBlock( - jitBlock, createTerminator, blockArgTypes, importOptions); + return importer.importBlock(jitBlock, createTerminator, blockArgTypes, + importOptions); } diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h similarity index 85% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h index f36352058..7fce8b988 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/node_importer.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h @@ -36,11 +36,11 @@ using CreateTerminatorFn = /// are required to be for correctness. The code will internally attempt to /// adjust the types to the block argument types. /// TODO: Formalize what type conversions are allowed here. -MlirBlock importBlock( - MlirContext context, torch::jit::Block* jitBlock, - CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes = c10::nullopt, - const ImportOptions& importOptions = {}); +MlirBlock +importBlock(MlirContext context, torch::jit::Block *jitBlock, + CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes = c10::nullopt, + const ImportOptions &importOptions = {}); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.cpp similarity index 74% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.cpp index fc8858734..afac7b164 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.cpp @@ -26,8 +26,8 @@ using namespace torch_mlir; -static MlirType getMlirTypeForTorchScalarTypeRaw( - MlirContext context, c10::ScalarType scalarType) { +static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, + c10::ScalarType scalarType) { using c10::ScalarType; switch (scalarType) { case ScalarType::Byte: @@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw( } } -MlirType torch_mlir::getMlirTypeForTorchScalarType( - MlirLocation loc, c10::ScalarType scalarType) { +MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc, + c10::ScalarType scalarType) { auto type = getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType); if (mlirTypeIsNull(type)) { @@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType( // There is no generic way to import custom classes (or their types), so we // have to name match them here (and the relevant code in the ivalue // importer) and create special IR constructs for them. -static MlirType mapCustomClassType( - MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) { +static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, + const c10::ClassTypePtr &classType) { // If the type is unnamed, it cannot be a custom class. if (!classType->name().has_value()) { return {nullptr}; @@ -126,9 +126,10 @@ static MlirType mapCustomClassType( throw mlir_diagnostic_emitted(); } -MlirType torch_mlir::getMlirTypeFromTorchType( - MlirLocation loc, const c10::TypePtr& torchType, - const ImportOptions& importOptions) { +MlirType +torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, + const c10::TypePtr &torchType, + const ImportOptions &importOptions) { MlirContext context = mlirLocationGetContext(loc); using c10::TypeKind; auto kind = torchType->kind(); @@ -140,11 +141,10 @@ MlirType torch_mlir::getMlirTypeFromTorchType( : torchMlirTorchNonValueTensorTypeGet; if (importOptions.ignoreExistingTensorShapesAndDtypes) { - return getMlirTensorType( - context, - /*numSizes=*/-1, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/{nullptr}); + return getMlirTensorType(context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/{nullptr}); } // Element type. @@ -156,18 +156,17 @@ MlirType torch_mlir::getMlirTypeFromTorchType( return {nullptr}; } // Sizes. - auto& sizes = tensorType->symbolic_sizes(); + auto &sizes = tensorType->symbolic_sizes(); if (!sizes.rank()) { // Unranked. - return getMlirTensorType( - context, - /*numSizes=*/-1, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/ - elementType); + return getMlirTensorType(context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/ + elementType); } // Ranked with possibly dynamic dims. - auto& symbolicShape = tensorType->symbolic_sizes(); + auto &symbolicShape = tensorType->symbolic_sizes(); std::vector dims; dims.resize(*sizes.rank()); for (size_t i = 0; i < dims.size(); ++i) { @@ -180,12 +179,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType( // the C API constructor, when we want the "we know we have 0 sizes" // case. So use a dummy data pointer. int64_t dummy; - int64_t* dimsData = dims.size() == 0 ? &dummy : dims.data(); - return getMlirTensorType( - context, dims.size(), - /*optionalSizes=*/dimsData, - /*optionalDtype=*/ - elementType); + int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data(); + return getMlirTensorType(context, dims.size(), + /*optionalSizes=*/dimsData, + /*optionalDtype=*/ + elementType); } case TypeKind::IntType: { return torchMlirTorchIntTypeGet(context); @@ -209,22 +207,22 @@ MlirType torch_mlir::getMlirTypeFromTorchType( } case TypeKind::TupleType: { std::vector containedTypes; - for (const c10::TypePtr& type : + for (const c10::TypePtr &type : torchType->cast()->containedTypes()) { containedTypes.push_back( getMlirTypeFromTorchType(loc, type, importOptions)); } - return torchMlirTorchTupleTypeGet( - context, containedTypes.size(), containedTypes.data()); + return torchMlirTorchTupleTypeGet(context, containedTypes.size(), + containedTypes.data()); } case TypeKind::UnionType: { std::vector containedTypes; - for (const c10::TypePtr& type : + for (const c10::TypePtr &type : torchType->cast()->containedTypes()) { containedTypes.push_back(getMlirTypeFromTorchType(loc, type)); } - return torchMlirTorchUnionTypeGet( - context, containedTypes.size(), containedTypes.data()); + return torchMlirTorchUnionTypeGet(context, containedTypes.size(), + containedTypes.data()); } case TypeKind::ListType: { return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( @@ -244,7 +242,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType( return torchMlirTorchAnyTypeGet(context); } case TypeKind::ClassType: { - const c10::ClassTypePtr& classType = torchType->cast(); + const c10::ClassTypePtr &classType = torchType->cast(); MlirType customClassType = mapCustomClassType(context, loc, classType); if (!mlirTypeIsNull(customClassType)) { return customClassType; @@ -268,11 +266,12 @@ MlirType torch_mlir::getMlirTypeFromTorchType( } } -MlirType torch_mlir::getFunctionTypeFromSchema( - MlirContext context, const c10::FunctionSchema& schema, - const ImportOptions& importOptions) { +MlirType +torch_mlir::getFunctionTypeFromSchema(MlirContext context, + const c10::FunctionSchema &schema, + const ImportOptions &importOptions) { MlirLocation loc = mlirLocationUnknownGet(context); - auto mapType = [&](const c10::TypePtr& torchType) { + auto mapType = [&](const c10::TypePtr &torchType) { MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions); if (mlirTypeIsNull(type)) { std::stringstream msg; @@ -284,20 +283,17 @@ MlirType torch_mlir::getFunctionTypeFromSchema( }; std::vector inputTypes = - c10::fmap(schema.arguments(), [&](const c10::Argument& arg) { - return mapType(arg.type()); - }); + c10::fmap(schema.arguments(), + [&](const c10::Argument &arg) { return mapType(arg.type()); }); std::vector outputTypes = - c10::fmap(schema.returns(), [&](const c10::Argument& arg) { - return mapType(arg.type()); - }); - return mlirFunctionTypeGet( - context, inputTypes.size(), inputTypes.data(), outputTypes.size(), - outputTypes.data()); + c10::fmap(schema.returns(), + [&](const c10::Argument &arg) { return mapType(arg.type()); }); + return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), + outputTypes.size(), outputTypes.data()); } -MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( - at::Tensor tensor, MlirLocation loc) { +MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, + MlirLocation loc) { using at::ScalarType; auto throwUnsupportedTensorError = [&]() { @@ -312,8 +308,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( // The flat number of bytes throws an exception for tensors that are not // dense and accessible as such. - at::checkLayout( - at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided); + at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor, + c10::Layout::Strided); // Construct the ShapedType. @@ -338,47 +334,47 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( switch (tensor.scalar_type()) { case ScalarType::Int: return mlirDenseElementsAttrInt32Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Long: return mlirDenseElementsAttrInt64Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Float: return mlirDenseElementsAttrFloatGet( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Double: return mlirDenseElementsAttrDoubleGet( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); break; case ScalarType::Bool: { // TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed // upstream to take in a `const bool *` rather than a `const int *` to avoid // the unnecessary copying into an array four times as large. - const int8_t* elements = static_cast(tensorData); + const int8_t *elements = static_cast(tensorData); std::vector tensorDataVector(elements, elements + numElements); - return mlirDenseElementsAttrBoolGet( - shapedType, numElements, tensorDataVector.data()); + return mlirDenseElementsAttrBoolGet(shapedType, numElements, + tensorDataVector.data()); } break; case ScalarType::QInt8: return mlirDenseElementsAttrInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::QUInt8: return mlirDenseElementsAttrUInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::BFloat16: return mlirDenseElementsAttrBFloat16Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Half: return mlirDenseElementsAttrFloat16Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Byte: return mlirDenseElementsAttrUInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); case ScalarType::Char: return mlirDenseElementsAttrInt8Get( - shapedType, numElements, static_cast(tensorData)); + shapedType, numElements, static_cast(tensorData)); default: throwUnsupportedTensorError(); @@ -386,8 +382,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( return {nullptr}; // Unreachable. } -MlirAttribute torch_mlir::importAttribute( - MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) { +MlirAttribute torch_mlir::importAttribute(MlirLocation loc, + torch::jit::Node *node, + c10::Symbol symbol) { MlirContext context = mlirLocationGetContext(loc); auto kind = node->kindOf(symbol); switch (kind) { @@ -396,8 +393,8 @@ MlirAttribute torch_mlir::importAttribute( // do that. return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol)); case torch::jit::AttributeKind::f: - return mlirFloatAttrDoubleGet( - context, mlirF64TypeGet(context), node->f(symbol)); + return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), + node->f(symbol)); case torch::jit::AttributeKind::s: return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol))); case torch::jit::AttributeKind::t: @@ -411,23 +408,23 @@ MlirAttribute torch_mlir::importAttribute( } } -MlirLocation torch_mlir::getMlirLocationFromNode( - MlirContext context, torch::jit::Node* node) { +MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, + torch::jit::Node *node) { MlirLocation loc = mlirLocationUnknownGet(context); if (node->hasAttribute(c10::Symbol::attr("source_files"))) { - const auto& sourceFiles = node->ss(c10::Symbol::attr("source_files")); - const auto& lineNumbers = node->is(c10::Symbol::attr("line_numbers")); - const auto& functions = node->ss(c10::Symbol::attr("functions")); + const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files")); + const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers")); + const auto &functions = node->ss(c10::Symbol::attr("functions")); // Chain a sequence of calls to construct single MlirLocation. for (const auto i : c10::irange(sourceFiles.size())) { MlirLocation newLoc = mlirLocationNameGet( context, toMlirStringRef(functions[i]), - mlirLocationFileLineColGet( - context, toMlirStringRef(sourceFiles[i]), lineNumbers[i], - 0 /* column is not available */ - )); + mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]), + lineNumbers[i], + 0 /* column is not available */ + )); loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc)); } if (sourceFiles.size() == 1) { @@ -436,7 +433,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode( loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context)); } } else if (auto flc = node->sourceRange().file_line_col()) { - const std::string& file = std::get<0>(*flc); + const std::string &file = std::get<0>(*flc); int line = std::get<1>(*flc); int col = std::get<2>(*flc); loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col); @@ -448,7 +445,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode( locationName = scopeName; } - if (const c10::FunctionSchema* schema = node->maybeSchema()) { + if (const c10::FunctionSchema *schema = node->maybeSchema()) { if (!locationName.empty()) { locationName += "/"; } @@ -462,9 +459,10 @@ MlirLocation torch_mlir::getMlirLocationFromNode( return loc; } -std::vector torch_mlir::getMlirTypesFromValues( - MlirLocation loc, c10::ArrayRef values, - const ImportOptions& importOptions) { +std::vector +torch_mlir::getMlirTypesFromValues(MlirLocation loc, + c10::ArrayRef values, + const ImportOptions &importOptions) { std::vector ret; for (auto value : values) { MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions); @@ -493,24 +491,25 @@ std::vector torch_mlir::adjustStaticInformationForValues( } std::stringstream msg; - MlirStringCallback printToStream = +[](MlirStringRef str, void* userData) { - std::stringstream* stream = static_cast(userData); + MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) { + std::stringstream *stream = static_cast(userData); stream->write(str.data, str.length); }; msg << "unhandled: could not adjust static info for type from "; - mlirTypePrint(type, printToStream, static_cast(&msg)); + mlirTypePrint(type, printToStream, static_cast(&msg)); msg << " to type "; - mlirTypePrint(expectedType, printToStream, static_cast(&msg)); + mlirTypePrint(expectedType, printToStream, static_cast(&msg)); mlirEmitError(loc, msg.str().c_str()); throw mlir_diagnostic_emitted(); } return ret; } -MlirOperation torch_mlir::createOperationFromSchema( - MlirBlock appendToBlock, MlirLocation loc, - const c10::FunctionSchema& schema, c10::ArrayRef resultTypes, - c10::ArrayRef operands) { +MlirOperation +torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc, + const c10::FunctionSchema &schema, + c10::ArrayRef resultTypes, + c10::ArrayRef operands) { MlirContext context = mlirLocationGetContext(loc); // Munge the name into the appropriate MLIR operation name. @@ -520,15 +519,15 @@ MlirOperation torch_mlir::createOperationFromSchema( auto separatorPosition = opNameSuffix.find_first_of("::"); assert(separatorPosition != std::string::npos); opNameSuffix.replace(separatorPosition, 2, "."); - const std::string& overloadName = schema.overload_name(); + const std::string &overloadName = schema.overload_name(); if (!overloadName.empty()) { opNameSuffix = opNameSuffix + "." + overloadName; } std::string opName = "torch." + opNameSuffix; // If we have a registered op, use it! if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) { - return createMlirOperationAtEnd( - appendToBlock, opName, loc, resultTypes, operands); + return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes, + operands); } // Oops, no registered op -- create an opaque wrapper so that import can // still succeed. This helps a common use case of filling out registered ops diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.h similarity index 61% rename from projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h rename to projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.h index eea49999b..82f394999 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/torch_to_mlir_utils.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.h @@ -25,7 +25,7 @@ namespace torch_mlir { /// Thrown on failure when details are in MLIR emitted diagnostics. class mlir_diagnostic_emitted : public std::runtime_error { public: - mlir_diagnostic_emitted(const char* what) : std::runtime_error(what) {} + mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} }; @@ -38,36 +38,37 @@ public: /// for Python code). /// /// Returns a null type on failure and emits a diagnostic. -MlirType -getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType); +MlirType getMlirTypeForTorchScalarType(MlirLocation loc, + c10::ScalarType scalarType); /// Maps a torch type to a corresponding MlirType. Returns a null type /// on failure and emits a diagnostic. -MlirType getMlirTypeFromTorchType( - MlirLocation loc, const c10::TypePtr& torchType, - const ImportOptions& importOptions = {}); +MlirType getMlirTypeFromTorchType(MlirLocation loc, + const c10::TypePtr &torchType, + const ImportOptions &importOptions = {}); /// Creates a FunctionType suitable for expressing the signature of `schema`. /// /// This can differ from the type inferred from the block of a /// torch::jit::Function due to derefinement and refinement of tensor types. -MlirType getFunctionTypeFromSchema( - MlirContext context, const c10::FunctionSchema& schema, - const ImportOptions& importOptions = {}); +MlirType getFunctionTypeFromSchema(MlirContext context, + const c10::FunctionSchema &schema, + const ImportOptions &importOptions = {}); /// Creates an appropriate MlirAttribute that holds the same values as `tensor`. -MlirAttribute -convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc); +MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor, + MlirLocation loc); -MlirAttribute -importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol); +MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node, + c10::Symbol symbol); -MlirLocation -getMlirLocationFromNode(MlirContext context, torch::jit::Node* node); +MlirLocation getMlirLocationFromNode(MlirContext context, + torch::jit::Node *node); -std::vector getMlirTypesFromValues( - MlirLocation loc, c10::ArrayRef values, - const ImportOptions& importOptions = {}); +std::vector +getMlirTypesFromValues(MlirLocation loc, + c10::ArrayRef values, + const ImportOptions &importOptions = {}); std::vector adjustStaticInformationForValues( MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef values, @@ -78,10 +79,11 @@ std::vector adjustStaticInformationForValues( /// /// The primary difficulty here is doing the appropriate name munging and /// checking if the have a registered op. -MlirOperation createOperationFromSchema( - MlirBlock appendToBlock, MlirLocation loc, - const c10::FunctionSchema& schema, c10::ArrayRef resultTypes, - c10::ArrayRef operands); +MlirOperation createOperationFromSchema(MlirBlock appendToBlock, + MlirLocation loc, + const c10::FunctionSchema &schema, + c10::ArrayRef resultTypes, + c10::ArrayRef operands); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt index bcf1ec89d..5ae5ddf0a 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt @@ -1,30 +1,3 @@ -# Static library with core functionality. -# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) -# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 -add_library(TorchMLIRJITIRImporter STATIC - class_annotator.cpp - function_importer.cpp - node_importer.cpp - ivalue_importer.cpp - torch_to_mlir_utils.cpp - ) -target_link_libraries(TorchMLIRJITIRImporter - TorchMLIRAggregateCAPI - ${TORCH_LIBRARIES} - ) -# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...") -target_include_directories(TorchMLIRJITIRImporter PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/.. -) -set_target_properties(TorchMLIRJITIRImporter PROPERTIES - LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" - OUTPUT_NAME lib_jit_ir_importer - PREFIX "" - SUFFIX ".a" - CXX_VISIBILITY_PRESET "default" - COMPILE_FLAGS "${TORCH_CXXFLAGS}" - ) - # Separate Pybind MODULE due to issues with a SHARED library. # https://github.com/llvm/torch-mlir/issues/1154 add_library(TorchMLIRJITIRImporterPybind MODULE diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp index 2e5296820..c1219d48d 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// #include "class_annotator_pybind.h" -#include "class_annotator.h" +#include "jit_ir_importer/class_annotator.h" #include #include diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp index 3e0183a95..94a47229d 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// #include "import_options_pybind.h" -#include "import_options.h" +#include "jit_ir_importer/import_options.h" namespace py = pybind11; diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp index c1922f8f0..92f131b0d 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp @@ -9,9 +9,9 @@ #include "module_builder.h" -#include "function_importer.h" -#include "ivalue_importer.h" -#include "mlir_utils.h" +#include "jit_ir_importer/function_importer.h" +#include "jit_ir_importer/ivalue_importer.h" +#include "jit_ir_importer/mlir_utils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" diff --git a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h index 273778c41..cff2200d3 100644 --- a/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H #define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H -#include "class_annotator.h" +#include "jit_ir_importer/class_annotator.h" #include "mlir-c/IR.h"