//===- ivalue_importer.cpp ------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "ivalue_importer.h"
#include "class_annotator.h"
#include "function_importer.h"

#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "mlir_utils.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"

using namespace torch_mlir;

// Hashing functionality for IValue's.
//
// What we want here is a strict object identity hash. This is different from
// what Python usually treats as hashing, which is a deep equality hash. In
// Python terms, what we want here is a hash of `id(x)` -- unfortunately,
// IValue is not uniformly heap allocated the way a `PyObject*` is, so special
// handling is needed. At the time of this writing, there seem to be two
// different implementations, neither of which is exactly what we want.
//
// - c10::IValue::hash static method
//   - Problem: Doesn't handle certain data types, in particular objects (which
//     modules are a special case of) and lists/dicts. This makes sense when
//     reflecting the Python semantics.
// - c10::WeakIValue::hash method
//   - Problem: it literally just returns the bits of the "union" as an int.
//     This seems to read uninitialized bits for the bool variant.
//
// We use the `c10::IValue::hash` static method with special cases for data
// types that need their identity to be handled specially. `c10::IValue::hash`
// seems to be implemented in a principled way following the Python semantics,
// which is compatible with the semantics we want (for the subset it doesn't
// throw an error on).
namespace {
struct IValueHasher {
  size_t operator()(const c10::IValue &ivalue) const {
    // Objects and lists are hashed by their heap pointer, giving object
    // identity semantics; everything else defers to c10::IValue::hash.
    if (ivalue.isObject() || ivalue.isList()) {
      return std::hash<c10::intrusive_ptr_target *>()(
          static_cast<c10::intrusive_ptr_target *>(
              ivalue.internalToPointer()));
    }

    return c10::IValue::hash(ivalue);
  }
};
} // namespace

// TODO: The implementation of isSameIdentity looks vulnerable to malloc
// reusing the same memory block (if this hash function is used in an online
// setting, such as when tracing). Can we do better?
namespace {
struct IValueEq {
  bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
    return lhs.isSameIdentity(rhs);
  }
};
} // namespace
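
// For illustration, a minimal sketch of what identity hashing buys us
// (assuming only the standard c10 API; this snippet is not part of this
// file's logic):
//
//   c10::List<int64_t> a{1, 2, 3};
//   c10::List<int64_t> b{1, 2, 3}; // equal contents, distinct object
//   c10::IValue ia(a), ib(b);
//   assert(!ia.isSameIdentity(ib));
//   // A deep-equality hash would unify `ia` and `ib`; IValueHasher hashes
//   // their heap pointers instead, so a map keyed with IValueHasher/IValueEq
//   // keeps them as two distinct entries, matching Python `id(x)` semantics.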

namespace {
/// Helper class for holding state during recursive IValue import.
///
/// The intended usage pattern of this class is to construct it and then call
/// `importIValue`.
///
/// The `importIValue` method can be called more than once, and values are
/// unified *by object identity*. For types isomorphic to Python builtin types
/// the behavior is what you would expect from `id(x)`.
///
/// For tensors, object identity is a little tricky. As background, at::Tensor
/// basically has 4 parts:
/// - at::Tensor which is a smart pointer to at::TensorImpl
/// - at::TensorImpl which holds sizes/strides/etc. and points to at::Storage
///   - the address of the at::TensorImpl is the identity of the tensor.
/// - at::Storage which is a smart pointer to at::StorageImpl
/// - at::StorageImpl which is a low-level buffer
///   - the address of the at::StorageImpl is the identity of the "storage".
///
/// Multiple different tensors can share the same underlying storage. We
/// currently import tensors by identity and emit errors in the case of
/// tensors with different identity but sharing the same storage. This is done
/// because correctly modeling the many ways that tensors can overlap and
/// alias when they share storage is difficult. Example hard cases are weird
/// strides/offsets that overlap, and even cases where the data types mismatch
/// (PyTorch allows this!).
class IValueImporter {
public:
  IValueImporter(MlirBlock importBlock, MlirContext context,
                 ClassAnnotator &annotator)
      : importBlock(importBlock), context(context), typeMapper(context),
        annotator(annotator) {}

  MlirValue importIValue(c10::IValue value);

private:
  MlirValue rawImportIValue(c10::IValue value);
  MlirValue importModule(torch::jit::Module jitModule);
  void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
                    const MethodAnnotation &methodAnnotation);
  void importClassType(c10::ClassType *classType);
  void importCompilationUnit(torch::jit::CompilationUnit *cu);

  MlirBlock importBlock;
  MlirContext context;
  TypeMapper typeMapper;
  ClassAnnotator &annotator;

  // Map tracking already-imported values.
  std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;

  // The unique compilation unit that is shared by all modules reachable
  // from the root ivalue being imported.
  // It basically contains a symbol table of functions which are referenced
  // from e.g. methods (the function names are meaningful and match with
  // Python's module hierarchy, with the exception of `__main__` being
  // replaced with `__torch__`).
  torch::jit::CompilationUnit *compilationUnit = nullptr;

  // Used to detect potentially aliasing tensors.
  std::unordered_set<c10::StorageImpl *> seenStorageImpls;

  // The set of ClassType's that have already been imported.
  //
  // ClassType's are referenced via their `classType->name()->qualifiedName()`
  // string (as an MLIR symbol name) so we don't need to keep a map associating
  // them with the MlirOperation that they import into.
  std::unordered_set<c10::ClassType *> classTypes;

  // The stack of attribute names we have traversed to reach the current
  // IValue. Used for diagnostics.
  std::vector<std::string> attributeNameStack;

  // The name of the root module encountered during recursive IValue
  // traversal. Used for diagnostics.
  // Note that the "top-level" object being imported can in theory be a list
  // of modules, so this is populated when our recursive traversal enters a
  // module not enclosed in any other module, and unset after our recursive
  // traversal exits the module.
  c10::optional<std::string> rootModuleName;
};
} // namespace
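
// A minimal sketch of the storage-aliasing case described above (assuming
// only the standard ATen API; illustrative, not part of this file's logic):
//
//   at::Tensor base = at::zeros({4});
//   at::Tensor view = base.slice(/*dim=*/0, /*start=*/0, /*end=*/2);
//   // Distinct tensor identities...
//   assert(base.unsafeGetTensorImpl() != view.unsafeGetTensorImpl());
//   // ...but a shared storage identity, so importing both `base` and `view`
//   // through this class raises an error instead of silently duplicating or
//   // mis-aliasing the data.
//   assert(base.storage().unsafeGetStorageImpl() ==
//          view.storage().unsafeGetStorageImpl());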

MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
  if (!maybeName) {
    throw std::invalid_argument("cannot import unnamed module");
  }
  std::string moduleTypeName = maybeName->qualifiedName();

  // If this is the first time we are encountering a module, import the
  // compilation unit.
  importCompilationUnit(currentModule._ivalue()->compilation_unit().get());

  // Ensure the class type has been imported.
  importClassType(currentModule.type().get());

  MlirOperation nnModule = createMlirOperation(
      "torch.nn_module", loc,
      npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
      mlirRegionCreate());
  MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
  MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);

  // Record the root module name for diagnostics if this is the outermost
  // module, and remember whether this invocation set it so that nested
  // module imports don't clear it prematurely.
  bool isRootModule = !rootModuleName.has_value();
  if (isRootModule) {
    rootModuleName = moduleTypeName;
  }

  const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
  const std::vector<c10::ClassAttribute> &classAttributes =
      currentModule.type()->getAttributes();
  assert(slots.size() == classAttributes.size() &&
         "mismatch between object and type!");
  for (int i = 0, e = slots.size(); i < e; i++) {
    const c10::ClassAttribute &classAttribute = classAttributes[i];
    attributeNameStack.push_back(classAttribute.getName());
    MlirValue slotValue = importIValue(slots[i]);
    // TODO: Is it necessary to track whether an attribute is a "parameter"?
    createMlirOperationAtEnd(
        nnModuleBody, "torch.slot", loc, slotValue,
        toMlirNamedAttribute(
            "name", mlirStringAttrGet(
                        context, toMlirStringRef(classAttribute.getName()))));
    attributeNameStack.pop_back();
  }

  if (isRootModule) {
    rootModuleName = c10::nullopt;
  }

  createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);

  mlirBlockInsertOwnedOperationBefore(
      importBlock, mlirBlockGetTerminator(importBlock), nnModule);
  return mlirOperationGetResult(nnModule, 0);
}

MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
  auto it = valueMap.find(ivalue);
  if (it != valueMap.end()) {
    return it->second;
  }
  // Reject potentially aliased tensors.
  if (ivalue.isTensor()) {
    c10::StorageImpl *storageImpl =
        ivalue.toTensor().storage().unsafeGetStorageImpl();
    if (!seenStorageImpls.insert(storageImpl).second) {
      std::stringstream msg;
      msg << "Unhandled tensor that shares storage with another tensor.";
      if (rootModuleName) {
        msg << "\nFound at path '."
            << c10::QualifiedName(attributeNameStack).qualifiedName()
            << "' from root object '" << *rootModuleName << "'";
      }
      throw std::invalid_argument(msg.str());
    }
  }
  MlirValue value = rawImportIValue(ivalue);
  valueMap[ivalue] = value;
  return value;
}
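
// A sketch of the unification behavior of `importIValue` (hypothetical
// caller; `block`, `context`, and `annotator` are assumed to be set up as in
// torch_mlir::importIValue at the bottom of this file):
//
//   IValueImporter importer(block, context, annotator);
//   c10::IValue t(at::ones({2}));
//   MlirValue v1 = importer.importIValue(t);
//   MlirValue v2 = importer.importIValue(t); // memoized by object identity
//   assert(mlirValueEqual(v1, v2));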

MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  if (ivalue.isBool()) {
    MlirType type = npcompBoolTypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.bool_constant", loc, type,
        toMlirNamedAttribute("value",
                             mlirBoolAttrGet(context, ivalue.toBool())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isDouble()) {
    MlirType type = mlirF64TypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.numeric_constant", loc, type,
        toMlirNamedAttribute(
            "value",
            mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isInt()) {
    MlirType type = mlirIntegerTypeGet(context, 64);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.numeric_constant", loc, type,
        toMlirNamedAttribute("value",
                             mlirIntegerAttrGet(type, ivalue.toInt())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isList()) {
    c10::List<c10::IValue> list = ivalue.toList();
    std::vector<MlirValue> elems;
    for (const c10::IValue &elem : list) {
      elems.push_back(importIValue(elem));
    }
    MlirOperation operation =
        createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
                                 npcompListTypeGet(context), elems);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTuple()) {
    auto list = ivalue.toTuple()->elements();
    std::vector<MlirValue> elems;
    for (const c10::IValue &elem : list) {
      elems.push_back(importIValue(elem));
    }
    MlirOperation operation =
        createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
                                 npcompTupleTypeGet(context), elems);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTensor()) {
    at::Tensor tensor = ivalue.toTensor().contiguous();
    MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
    MlirOperation constant = createMlirOperationAtEnd(
        importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
        toMlirNamedAttribute("value", denseElements));
    MlirOperation ndarray = createMlirOperationAtEnd(
        importBlock, "numpy.create_array_from_tensor", loc,
        npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
        mlirOperationGetResult(constant, 0));
    return mlirOperationGetResult(ndarray, 0);
  }
  if (ivalue.isModule()) {
    return importModule(ivalue.toModule());
  }
  if (ivalue.isString()) {
    MlirType type = npcompBytesTypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.bytes_constant", loc, type,
        toMlirNamedAttribute(
            "value",
            mlirStringAttrGet(context,
                              toMlirStringRef(ivalue.toString()->string()))));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isNone()) {
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
    return mlirOperationGetResult(operation, 0);
  }
  std::stringstream msg;
  msg << "Unsupported ivalue: " << ivalue;
  throw std::invalid_argument(msg.str());
}
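
// Note that IValue kinds without a case above (e.g. dicts) fall through to
// the error. A sketch (assuming only the standard c10 API; illustrative):
//
//   c10::Dict<std::string, int64_t> d;
//   // rawImportIValue(d) throws
//   // std::invalid_argument("Unsupported ivalue: ...").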

void IValueImporter::importMethod(torch::jit::Function *function,
                                  MlirBlock classTypeBody,
                                  const MethodAnnotation &methodAnnotation) {
  // The function's name becomes the MLIR symbol table name of the imported
  // func when we import the compilation unit.
  const std::string &symName = function->qualname().qualifiedName();
  MlirAttribute functionSymbolRef =
      mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));

  c10::optional<MlirNamedAttribute> isPrivate;
  if (!methodAnnotation.isExported) {
    isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
  }
  createMlirOperationAtEnd(
      classTypeBody, "torch.method", mlirLocationUnknownGet(context),
      toMlirNamedAttribute(
          "name",
          mlirStringAttrGet(context, toMlirStringRef(function->name()))),
      toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
}

void IValueImporter::importClassType(c10::ClassType *classType) {
  if (!classTypes.insert(classType).second) {
    return;
  }

  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  MlirOperation op = createMlirOperationAtEnd(
      importBlock, "torch.class_type", loc, mlirRegionCreate(),
      toMlirNamedAttribute(
          "sym_name",
          mlirStringAttrGet(
              context, toMlirStringRef(classType->name()->qualifiedName()))));
  MlirRegion region = mlirOperationGetRegion(op, 0);
  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
  MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);

  ClassAnnotation &classAnnotation =
      annotator.getOrCreateClassAnnotation(classType);

  const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
  const auto &classAttributes = classType->getAttributes();
  for (int i = 0, e = classAttributes.size(); i != e; i++) {
    const c10::ClassAttribute &classAttribute = classAttributes[i];
    c10::optional<MlirNamedAttribute> isPrivate;
    if (!attributeAnnotations[i].isExported) {
      isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
    }
    createMlirOperationAtEnd(
        classTypeBody, "torch.attr", loc,
        toMlirNamedAttribute(
            "name", mlirStringAttrGet(
                        context, toMlirStringRef(classAttribute.getName()))),
        toMlirNamedAttribute("type",
                             mlirTypeAttrGet(typeMapper.mapFromTorchType(
                                 loc, classAttribute.getType()))),
        isPrivate);
  }

  const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
  const auto &methods = classType->methods();
  for (int i = 0, e = methods.size(); i != e; i++) {
    importMethod(methods[i], classTypeBody, methodAnnotations[i]);
  }

  createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
}

void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
  if (compilationUnit == nullptr) {
    compilationUnit = cu;
  } else {
    // All sorts of stuff is connected to the compilation unit, such as
    // c10::ClassType's (owned by the compilation unit), c10::FunctionType
    // (which holds a pointer to a torch::jit::Function in the compilation
    // unit), load-bearing symbol table names of functions, etc.
    //
    // It doesn't seem to be defined how multiple compilation units
    // semantically connect with each other, and it doesn't seem to happen
    // either (though structurally at the C++ level nothing prevents it), so
    // make it an error.
    if (compilationUnit != cu) {
      throw std::invalid_argument(
          "found two compilation units while importing");
    }
    return;
  }

  for (torch::jit::Function *function : cu->get_functions()) {
    MethodAnnotation *annotation =
        annotator.getMethodAnnotationForFunction(function);
    MlirOperation func = importJitFunctionAsFuncOp(
        context, function, [&](int argIndex) -> MlirAttribute {
          if (!annotation || !annotation->argAnnotations.has_value()) {
            return {nullptr};
          }
          auto &shape = annotation->argAnnotations.value()[argIndex].shape;
          auto &dtype = annotation->argAnnotations.value()[argIndex].dtype;
          // TODO: Handle unranked tensors and tensors with unknown dtype (but
          // possibly known ranks/sizes).
          if (!shape || !dtype) {
            return {nullptr};
          }
          auto typeBound = npcompNdArrayTypeGetRanked(
              shape->size(), shape->data(),
              TypeMapper(context).mapFromTorchScalarType(
                  mlirLocationUnknownGet(context), *dtype));
          MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
              "torch.type_bound", mlirTypeAttrGet(typeBound));
          return mlirDictionaryAttrGet(context, 1, &typeBoundAttr);
        });
    // For IValue importing, the logical linkage structure of the module
    // is determined by the object graph.
    //
    // The functions' symbol names are thus irrelevant to the module's
    // externally visible characteristics, so mark them all as private.
    //
    // These functions may be referenced by the object graph, which can make
    // them reachable from the externally visible characteristics of the
    // module, but they cannot be intrinsically externally visible.
    mlirOperationSetAttributeByName(
        func, toMlirStringRef("sym_visibility"),
        mlirStringAttrGet(context, toMlirStringRef("private")));
    mlirBlockInsertOwnedOperationBefore(
        importBlock, mlirBlockGetTerminator(importBlock), func);
  }
}
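
// For reference, a rough sketch of the effect of the argument-annotation
// lambda above (the printed type syntax is an assumption, not normative): an
// argument annotated with shape [2, 3] and dtype float32 receives an
// argument attribute dictionary along the lines of
//
//   {torch.type_bound = !numpy.ndarray<[2,3]:f32>}
//
// while arguments lacking a full shape/dtype annotation return a null
// attribute, i.e. no dictionary at all.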

void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
                              MlirContext context, ClassAnnotator &annotator) {
  // When debugging module importing, it can be useful to dump as so:
  //   if (ivalue.isModule())
  //     ivalue.toModule().dump(true, false, false);
  IValueImporter importer(block, context, annotator);
  importer.importIValue(ivalue);
}
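
// A minimal usage sketch (hypothetical driver code; the module construction
// and the surrounding MLIR context/block setup are assumptions, not
// prescribed by this file):
//
//   torch::jit::Module m("MyModule");
//   m.register_attribute("foo", c10::IntType::get(), c10::IValue(int64_t(1)));
//   ClassAnnotator annotator;
//   // `block` must already have a terminator, since the importer inserts its
//   // operations before mlirBlockGetTerminator(block).
//   torch_mlir::importIValue(m._ivalue(), block, context, annotator);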