First step of moving the common jit_ir_importer.

breakup_python_pytorch_deps
Stella Laurenzo 2023-11-18 18:12:15 -08:00
parent 606dc45896
commit f1d9136210
19 changed files with 432 additions and 445 deletions

View File

@ -0,0 +1 @@
# Descend into the JIT IR importer sources (csrc/jit_ir_importer/CMakeLists.txt).
add_subdirectory(csrc/jit_ir_importer)

View File

@ -0,0 +1,26 @@
# Static library with core functionality.
# We can't use a shared library here, due to issues with linking on macOS-arm64
# (the library itself won't build).
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
add_library(TorchMLIRJITIRImporter STATIC
  class_annotator.cpp
  function_importer.cpp
  node_importer.cpp
  ivalue_importer.cpp
  torch_to_mlir_utils.cpp
)
target_link_libraries(TorchMLIRJITIRImporter
  TorchMLIRAggregateCAPI
  ${TORCH_LIBRARIES}
)
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
target_include_directories(TorchMLIRJITIRImporter PUBLIC
  ${CMAKE_CURRENT_SOURCE_DIR}/..
)
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
  # NOTE(review): for a STATIC library the archive location is controlled by
  # ARCHIVE_OUTPUT_DIRECTORY; LIBRARY_OUTPUT_DIRECTORY only applies to
  # shared/module libraries, so on its own it was a no-op here and the .a did
  # not land in the _mlir_libs package dir. Set both so the output is placed
  # correctly and stays correct if the library type ever changes.
  ARCHIVE_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
  LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
  # Force the file name to "lib_jit_ir_importer.a" on every platform
  # (PREFIX/SUFFIX override the platform defaults such as "lib<name>.a").
  OUTPUT_NAME lib_jit_ir_importer
  PREFIX ""
  SUFFIX ".a"
  # Symbols must stay visible so the Python extension can link against them.
  CXX_VISIBILITY_PRESET "default"
  COMPILE_FLAGS "${TORCH_CXXFLAGS}"
)

View File

@ -18,8 +18,8 @@ using namespace torch_mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Prefix every line of `s` with `linePrefix`. // Prefix every line of `s` with `linePrefix`.
static std::string static std::string indentString(const std::string &linePrefix,
indentString(const std::string& linePrefix, const std::string& s) { const std::string &s) {
std::stringstream is(s); std::stringstream is(s);
std::stringstream os; std::stringstream os;
std::string line; std::string line;
@ -39,28 +39,26 @@ ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType)
methodAnnotations.resize(classType->methods().size()); methodAnnotations.resize(classType->methods().size());
} }
std::vector<AttributeAnnotation>& ClassAnnotation::getAttributeAnnotations() { std::vector<AttributeAnnotation> &ClassAnnotation::getAttributeAnnotations() {
// Halfhearted attempt to ensure consistency if the class type has // Halfhearted attempt to ensure consistency if the class type has
// been mutated. // been mutated.
// //
// We can't easily guard against attributes being removed and // We can't easily guard against attributes being removed and
// then other attributes being added, or types changed, etc. without // then other attributes being added, or types changed, etc. without
// effectively mirroring the entire ClassType. // effectively mirroring the entire ClassType.
assert( assert(attributeAnnotations.size() == classType->getAttributes().size() &&
attributeAnnotations.size() == classType->getAttributes().size() && "annotations out of sync. class has been mutated");
"annotations out of sync. class has been mutated");
return attributeAnnotations; return attributeAnnotations;
} }
std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() { std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() {
// Halfhearted attempt to ensure consistency if the class type has // Halfhearted attempt to ensure consistency if the class type has
// been mutated. // been mutated.
// //
// We can't easily guard against methods being removed, added, or changed. // We can't easily guard against methods being removed, added, or changed.
assert( assert(methodAnnotations.size() == classType->methods().size() &&
methodAnnotations.size() == classType->methods().size() && "annotations out of sync. class has been mutated");
"annotations out of sync. class has been mutated");
return methodAnnotations; return methodAnnotations;
} }
@ -69,17 +67,17 @@ std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() {
// ClassAnnotator // ClassAnnotator
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void static void exportNoneRecurse(ClassAnnotator &classAnnotator,
exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) { c10::ClassType *classType) {
ClassAnnotation& classAnnotation = ClassAnnotation &classAnnotation =
classAnnotator.getOrCreateClassAnnotation(classType); classAnnotator.getOrCreateClassAnnotation(classType);
for (auto& attributeAnnotation : classAnnotation.getAttributeAnnotations()) { for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) {
attributeAnnotation.isExported = false; attributeAnnotation.isExported = false;
} }
for (auto& methodAnnotation : classAnnotation.getMethodAnnotations()) { for (auto &methodAnnotation : classAnnotation.getMethodAnnotations()) {
methodAnnotation.isExported = false; methodAnnotation.isExported = false;
} }
for (auto& classAttribute : classType->getAttributes()) { for (auto &classAttribute : classType->getAttributes()) {
if (auto childClassType = if (auto childClassType =
classAttribute.getType()->cast<c10::ClassType>()) { classAttribute.getType()->cast<c10::ClassType>()) {
exportNoneRecurse(classAnnotator, childClassType.get()); exportNoneRecurse(classAnnotator, childClassType.get());
@ -87,20 +85,20 @@ exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) {
} }
} }
void ClassAnnotator::exportNone(c10::ClassType& rootClassType) { void ClassAnnotator::exportNone(c10::ClassType &rootClassType) {
exportNoneRecurse(*this, &rootClassType); exportNoneRecurse(*this, &rootClassType);
} }
void ClassAnnotator::exportPath( void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
c10::ClassType& rootClassType, std::vector<std::string> exportedPath) { std::vector<std::string> exportedPath) {
if (exportedPath.size() == 0) { if (exportedPath.size() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"Empty exported path. Can only export a property of a class."); "Empty exported path. Can only export a property of a class.");
} }
c10::ClassType* classType = getClassAtPath( c10::ClassType *classType =
&rootClassType, c10::ArrayRef<std::string>(exportedPath) getClassAtPath(&rootClassType, c10::ArrayRef<std::string>(exportedPath)
.slice(0, exportedPath.size() - 1) .slice(0, exportedPath.size() - 1)
.vec()); .vec());
if (!classType->findAttribute(exportedPath.back()) && if (!classType->findAttribute(exportedPath.back()) &&
!classType->findMethod(exportedPath.back())) { !classType->findMethod(exportedPath.back())) {
@ -110,10 +108,10 @@ void ClassAnnotator::exportPath(
<< exportedPath.back() << "'"; << exportedPath.back() << "'";
throw std::invalid_argument(ss.str()); throw std::invalid_argument(ss.str());
} }
ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
std::vector<AttributeAnnotation>& attributeAnnotations = std::vector<AttributeAnnotation> &attributeAnnotations =
classAnnotation.getAttributeAnnotations(); classAnnotation.getAttributeAnnotations();
const std::vector<c10::ClassAttribute>& classAttributes = const std::vector<c10::ClassAttribute> &classAttributes =
classType->getAttributes(); classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
if (classAttributes[i].getName() == exportedPath.back()) { if (classAttributes[i].getName() == exportedPath.back()) {
@ -121,9 +119,9 @@ void ClassAnnotator::exportPath(
} }
} }
std::vector<MethodAnnotation>& methodAnnotations = std::vector<MethodAnnotation> &methodAnnotations =
classAnnotation.getMethodAnnotations(); classAnnotation.getMethodAnnotations();
const std::vector<torch::jit::Function*>& methods = classType->methods(); const std::vector<torch::jit::Function *> &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
if (methods[i]->name() == exportedPath.back()) { if (methods[i]->name() == exportedPath.back()) {
methodAnnotations[i].isExported = true; methodAnnotations[i].isExported = true;
@ -131,12 +129,12 @@ void ClassAnnotator::exportPath(
} }
} }
const ClassAnnotationMap& ClassAnnotator::getAnnotationMap() { const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() {
return classAnnotations; return classAnnotations;
} }
ClassAnnotation& ClassAnnotation &
ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) { ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
auto className = classType->name()->qualifiedName(); auto className = classType->name()->qualifiedName();
auto it = classAnnotations.find(className); auto it = classAnnotations.find(className);
if (it == classAnnotations.end()) { if (it == classAnnotations.end()) {
@ -151,39 +149,39 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) {
return *it->second; return *it->second;
} }
static void fillArgAnnotations( static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
MethodAnnotation& methodAnnotation, std::vector<ArgAnnotation> argAnnotations,
std::vector<ArgAnnotation> argAnnotations, torch::jit::Function* function) { torch::jit::Function *function) {
if (argAnnotations.size() != function->num_inputs()) { if (argAnnotations.size() != function->num_inputs()) {
throw std::invalid_argument("Arg annotations should have one entry per " throw std::invalid_argument("Arg annotations should have one entry per "
"function parameter (including self)."); "function parameter (including self).");
} }
if (!methodAnnotation.argAnnotations.has_value()) { if (!methodAnnotation.argAnnotations.has_value()) {
methodAnnotation.argAnnotations.emplace( methodAnnotation.argAnnotations.emplace(function->num_inputs(),
function->num_inputs(), ArgAnnotation{}); ArgAnnotation{});
} }
methodAnnotation.argAnnotations = argAnnotations; methodAnnotation.argAnnotations = argAnnotations;
} }
void ClassAnnotator::annotateArgs( void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
c10::ClassType& rootClassType, std::vector<std::string> path, std::vector<std::string> path,
std::vector<ArgAnnotation> argAnnotations) { std::vector<ArgAnnotation> argAnnotations) {
if (path.size() == 0) { if (path.size() == 0) {
throw std::invalid_argument("Empty annotated path. Can only annotate " throw std::invalid_argument("Empty annotated path. Can only annotate "
"shapes/dtypes of a method of a class."); "shapes/dtypes of a method of a class.");
} }
c10::ClassType* classType = getClassAtPath( c10::ClassType *classType = getClassAtPath(
&rootClassType, &rootClassType,
c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec()); c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec());
// Throw error if no method on the class of the specified name. // Throw error if no method on the class of the specified name.
torch::jit::Function* function = &classType->getMethod(path.back()); torch::jit::Function *function = &classType->getMethod(path.back());
ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType); ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
std::vector<MethodAnnotation>& methodAnnotations = std::vector<MethodAnnotation> &methodAnnotations =
classAnnotation.getMethodAnnotations(); classAnnotation.getMethodAnnotations();
const std::vector<torch::jit::Function*>& methods = classType->methods(); const std::vector<torch::jit::Function *> &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
if (methods[i]->name() == path.back()) { if (methods[i]->name() == path.back()) {
fillArgAnnotations(methodAnnotations[i], argAnnotations, function); fillArgAnnotations(methodAnnotations[i], argAnnotations, function);
@ -193,9 +191,9 @@ void ClassAnnotator::annotateArgs(
return; return;
} }
c10::ClassType* ClassAnnotator::getClassAtPath( c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType,
c10::ClassType* rootClassType, std::vector<std::string> path) { std::vector<std::string> path) {
c10::ClassType* classType = rootClassType; c10::ClassType *classType = rootClassType;
// Reverse so that pop_back gives us the initial atoms first. // Reverse so that pop_back gives us the initial atoms first.
std::reverse(path.begin(), path.end()); std::reverse(path.begin(), path.end());
while (!path.empty()) { while (!path.empty()) {
@ -217,8 +215,8 @@ c10::ClassType* ClassAnnotator::getClassAtPath(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Helper methods // Helper methods
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
MethodAnnotation* MethodAnnotation *
ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) { ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) {
auto it = functionToMethodMap.find(function); auto it = functionToMethodMap.find(function);
if (it == functionToMethodMap.end()) { if (it == functionToMethodMap.end()) {
return nullptr; return nullptr;
@ -230,7 +228,7 @@ ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) {
// toString methods // toString methods
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
std::string AttributeAnnotation::toString(const std::string& name) { std::string AttributeAnnotation::toString(const std::string &name) {
std::stringstream ss; std::stringstream ss;
ss << "AttributeAnnotation('" << name << "') {\n"; ss << "AttributeAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n";
@ -261,7 +259,7 @@ std::string ArgAnnotation::toString(int argIndex) {
return ss.str(); return ss.str();
} }
std::string MethodAnnotation::toString(const std::string& name) { std::string MethodAnnotation::toString(const std::string &name) {
std::stringstream ss; std::stringstream ss;
ss << "MethodAnnotation('" << name << "') {\n"; ss << "MethodAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n";
@ -282,13 +280,13 @@ std::string ClassAnnotation::toString() {
std::stringstream ss; std::stringstream ss;
ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n"; ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n";
const std::vector<c10::ClassAttribute>& classAttributes = const std::vector<c10::ClassAttribute> &classAttributes =
classType->getAttributes(); classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
ss << indentString( ss << indentString(
" ", attributeAnnotations[i].toString(classAttributes[i].getName())); " ", attributeAnnotations[i].toString(classAttributes[i].getName()));
} }
const std::vector<torch::jit::Function*>& methods = classType->methods(); const std::vector<torch::jit::Function *> &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name())); ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name()));
} }
@ -299,7 +297,7 @@ std::string ClassAnnotation::toString() {
std::string ClassAnnotator::toString() { std::string ClassAnnotator::toString() {
std::stringstream ss; std::stringstream ss;
ss << "ClassAnnotator {\n"; ss << "ClassAnnotator {\n";
for (auto& p : classAnnotations) { for (auto &p : classAnnotations) {
ss << indentString(" ", p.second->toString()); ss << indentString(" ", p.second->toString());
} }
ss << "}\n"; ss << "}\n";

View File

@ -34,7 +34,7 @@ struct AttributeAnnotation {
// can be externally accessed. // can be externally accessed.
bool isExported = true; bool isExported = true;
std::string toString(const std::string& name); std::string toString(const std::string &name);
}; };
// An annotation of an argument of a method. // An annotation of an argument of a method.
@ -80,7 +80,7 @@ struct MethodAnnotation {
// large printout of the default ArgAnnotation for every method. // large printout of the default ArgAnnotation for every method.
c10::optional<std::vector<ArgAnnotation>> argAnnotations; c10::optional<std::vector<ArgAnnotation>> argAnnotations;
std::string toString(const std::string& name); std::string toString(const std::string &name);
}; };
// Annotations on a c10::ClassType. // Annotations on a c10::ClassType.
@ -107,10 +107,10 @@ public:
// Get the attribute annotations. // Get the attribute annotations.
// The length and order is the same as `classType->getAttributes()`. // The length and order is the same as `classType->getAttributes()`.
std::vector<AttributeAnnotation>& getAttributeAnnotations(); std::vector<AttributeAnnotation> &getAttributeAnnotations();
// Get the method annotations. // Get the method annotations.
// The length and order is the same as `classType->methods()`. // The length and order is the same as `classType->methods()`.
std::vector<MethodAnnotation>& getMethodAnnotations(); std::vector<MethodAnnotation> &getMethodAnnotations();
std::string toString(); std::string toString();
@ -141,14 +141,14 @@ public:
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
// have a submodule `a` and that submodule should have a method or attribute // have a submodule `a` and that submodule should have a method or attribute
// `b`. // `b`.
void exportPath( void exportPath(c10::ClassType &rootClassType,
c10::ClassType& rootClassType, std::vector<std::string> exportedPath); std::vector<std::string> exportedPath);
// Mark everything as not-exported. // Mark everything as not-exported.
// //
// This is kind of useless by itself, but together with `exportPath` allows // This is kind of useless by itself, but together with `exportPath` allows
// exporting a subset of known names out of a larger collection of unknown // exporting a subset of known names out of a larger collection of unknown
// names. // names.
void exportNone(c10::ClassType& rootClassType); void exportNone(c10::ClassType &rootClassType);
// Annotate shapes and dtypes of the arguments of a method at path `path` from // Annotate shapes and dtypes of the arguments of a method at path `path` from
// `rootClassType`. // `rootClassType`.
@ -159,23 +159,23 @@ public:
// a "has value semantics" boolean. // a "has value semantics" boolean.
// These will be put into an `ArgAnnotation` struct -- see there for // These will be put into an `ArgAnnotation` struct -- see there for
// precise definitions of the promised semantics of each entry. // precise definitions of the promised semantics of each entry.
void annotateArgs( void annotateArgs(c10::ClassType &rootClassType,
c10::ClassType& rootClassType, std::vector<std::string> path, std::vector<std::string> path,
std::vector<ArgAnnotation> argAnnotations); std::vector<ArgAnnotation> argAnnotations);
// The annotations collected so far. // The annotations collected so far.
const ClassAnnotationMap& getAnnotationMap(); const ClassAnnotationMap &getAnnotationMap();
// Get the ClassAnnotation corresponding to `classType`. // Get the ClassAnnotation corresponding to `classType`.
ClassAnnotation& getOrCreateClassAnnotation(c10::ClassType* classType); ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType);
// Helper to find the MethodAnnotation corresponding to a // Helper to find the MethodAnnotation corresponding to a
// torch::jit::Function, or null if not found. // torch::jit::Function, or null if not found.
// //
// Users could in principle scan all annotations to find this, but it's more // Users could in principle scan all annotations to find this, but it's more
// efficient to maintain the reverse mapping directly. // efficient to maintain the reverse mapping directly.
MethodAnnotation* MethodAnnotation *
getMethodAnnotationForFunction(torch::jit::Function* function); getMethodAnnotationForFunction(torch::jit::Function *function);
std::string toString(); std::string toString();
@ -183,11 +183,11 @@ private:
// Traverse `path` starting from `rootClassType` to find the ClassType // Traverse `path` starting from `rootClassType` to find the ClassType
// of a presumed nested submodule. Throw an error if there is no such // of a presumed nested submodule. Throw an error if there is no such
// submodule. // submodule.
c10::ClassType* c10::ClassType *getClassAtPath(c10::ClassType *rootClassType,
getClassAtPath(c10::ClassType* rootClassType, std::vector<std::string> path); std::vector<std::string> path);
ClassAnnotationMap classAnnotations; ClassAnnotationMap classAnnotations;
// Reverse mapping used to service getMethodAnnotationForFunction. // Reverse mapping used to service getMethodAnnotationForFunction.
std::unordered_map<torch::jit::Function*, MethodAnnotation*> std::unordered_map<torch::jit::Function *, MethodAnnotation *>
functionToMethodMap; functionToMethodMap;
}; };

View File

@ -21,9 +21,9 @@
using namespace torch_mlir; using namespace torch_mlir;
MlirOperation torch_mlir::importJitFunctionAsFuncOp( MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function* function, MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute, std::function<MlirAttribute(int)> getArgAttribute,
const ImportOptions& importOptions) { const ImportOptions &importOptions) {
// Useful for debugging: // Useful for debugging:
// graph->dump(); // graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
@ -63,11 +63,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
} }
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues, auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd( createMlirOperationAtEnd(appendToBlock, "func.return", loc,
appendToBlock, "func.return", loc, adjustStaticInformationForValues(
adjustStaticInformationForValues( appendToBlock, loc, yieldedValues, resultTypes,
appendToBlock, loc, yieldedValues, resultTypes, /*userAllowsRefinement=*/false));
/*userAllowsRefinement=*/false));
}; };
MlirBlock block = importBlock( MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(), context, torch::jit::toGraphFunction(*function).graph()->block(),

View File

@ -40,10 +40,10 @@ namespace torch_mlir {
/// null MlirAttribute is returned, no attribute will be attached to that /// null MlirAttribute is returned, no attribute will be attached to that
/// argument. /// argument.
MlirOperation importJitFunctionAsFuncOp( MlirOperation importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function* function, MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute = std::function<MlirAttribute(int)> getArgAttribute =
[](int) -> MlirAttribute { return {nullptr}; }, [](int) -> MlirAttribute { return {nullptr}; },
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -49,10 +49,10 @@ using namespace torch_mlir;
// throw an error on). // throw an error on).
namespace { namespace {
struct IValueHasher { struct IValueHasher {
size_t operator()(const c10::IValue& ivalue) const { size_t operator()(const c10::IValue &ivalue) const {
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
return std::hash<const void*>()( return std::hash<const void *>()(
static_cast<const void*>(ivalue.internalToPointer())); static_cast<const void *>(ivalue.internalToPointer()));
} }
return c10::IValue::hash(ivalue); return c10::IValue::hash(ivalue);
@ -65,7 +65,7 @@ struct IValueHasher {
// such as when tracing). Can we do better? // such as when tracing). Can we do better?
namespace { namespace {
struct IValueEq { struct IValueEq {
bool operator()(const c10::IValue& lhs, const c10::IValue& rhs) const { bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
return lhs.isSameIdentity(rhs); return lhs.isSameIdentity(rhs);
} }
}; };
@ -99,9 +99,8 @@ namespace {
/// (PyTorch allows this!). /// (PyTorch allows this!).
class IValueImporter { class IValueImporter {
public: public:
IValueImporter( IValueImporter(MlirBlock importBlock, MlirContext context,
MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator, ClassAnnotator &annotator, const ImportOptions &importOptions)
const ImportOptions& importOptions)
: importBlock(importBlock), context(context), annotator(annotator), : importBlock(importBlock), context(context), annotator(annotator),
importOptions(importOptions) {} importOptions(importOptions) {}
@ -111,16 +110,15 @@ private:
MlirValue rawImportIValue(c10::IValue ivalue); MlirValue rawImportIValue(c10::IValue ivalue);
MlirValue importTensor(c10::IValue ivalue); MlirValue importTensor(c10::IValue ivalue);
MlirValue importModule(torch::jit::Module jitModule); MlirValue importModule(torch::jit::Module jitModule);
void importMethod( void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
torch::jit::Function* function, MlirBlock classTypeBody, const MethodAnnotation &methodAnnotation);
const MethodAnnotation& methodAnnotation); void importClassType(c10::ClassType *classType);
void importClassType(c10::ClassType* classType); void importCompilationUnit(torch::jit::CompilationUnit *cu);
void importCompilationUnit(torch::jit::CompilationUnit* cu);
MlirBlock importBlock; MlirBlock importBlock;
MlirContext context; MlirContext context;
ClassAnnotator& annotator; ClassAnnotator &annotator;
const ImportOptions& importOptions; const ImportOptions &importOptions;
// Map tracking already-imported values. // Map tracking already-imported values.
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap; std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
@ -131,16 +129,16 @@ private:
// e.g. methods (the function names are meaningful and match with Python's // e.g. methods (the function names are meaningful and match with Python's
// module hierarchy, with the exception of `__main__` being replaced with // module hierarchy, with the exception of `__main__` being replaced with
// `__torch__`). // `__torch__`).
torch::jit::CompilationUnit* compilationUnit = nullptr; torch::jit::CompilationUnit *compilationUnit = nullptr;
// Used to detect potentially aliasing tensors. // Used to detect potentially aliasing tensors.
std::unordered_set<c10::StorageImpl*> seenStorageImpls; std::unordered_set<c10::StorageImpl *> seenStorageImpls;
// The set of ClassType's that have already been imported. // The set of ClassType's that have already been imported.
// //
// ClassType's are referenced via their `classType->name()->qualifiedName()` // ClassType's are referenced via their `classType->name()->qualifiedName()`
// string (as an MLIR symbol name) so we don't need to keep a map associating // string (as an MLIR symbol name) so we don't need to keep a map associating
// them with the MlirOperation that they import into. // them with the MlirOperation that they import into.
std::unordered_set<c10::ClassType*> classTypes; std::unordered_set<c10::ClassType *> classTypes;
// The stack of attribute names we have traversed to reach the current IValue. // The stack of attribute names we have traversed to reach the current IValue.
// Used for diagnostics. // Used for diagnostics.
std::vector<std::string> attributeNameStack; std::vector<std::string> attributeNameStack;
@ -192,8 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate()); mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(nnModuleRegion,
nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); mlirBlockCreate(0, nullptr, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
InserterGuard inserterGuard(importBlock, nnModule); InserterGuard inserterGuard(importBlock, nnModule);
@ -201,14 +199,13 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
rootModuleName = moduleTypeName; rootModuleName = moduleTypeName;
} }
const std::vector<c10::IValue>& slots = currentModule._ivalue()->slots(); const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
const std::vector<c10::ClassAttribute>& classAttributes = const std::vector<c10::ClassAttribute> &classAttributes =
currentModule.type()->getAttributes(); currentModule.type()->getAttributes();
assert( assert(slots.size() == classAttributes.size() &&
slots.size() == classAttributes.size() && "mismatch between object and type!");
"mismatch between object and type!");
for (int i = 0, e = slots.size(); i < e; i++) { for (int i = 0, e = slots.size(); i < e; i++) {
const c10::ClassAttribute& classAttribute = classAttributes[i]; const c10::ClassAttribute &classAttribute = classAttributes[i];
attributeNameStack.push_back(classAttribute.getName()); attributeNameStack.push_back(classAttribute.getName());
MlirValue slotValue = importIValue(slots[i]); MlirValue slotValue = importIValue(slots[i]);
// TODO: Is it necessary to track whether an attribute is a "parameter"? // TODO: Is it necessary to track whether an attribute is a "parameter"?
@ -235,7 +232,7 @@ MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
} }
// Reject potentially aliased tensors. // Reject potentially aliased tensors.
if (ivalue.isTensor()) { if (ivalue.isTensor()) {
c10::StorageImpl* storageImpl = c10::StorageImpl *storageImpl =
ivalue.toTensor().storage().unsafeGetStorageImpl(); ivalue.toTensor().storage().unsafeGetStorageImpl();
if (!seenStorageImpls.insert(storageImpl).second) { if (!seenStorageImpls.insert(storageImpl).second) {
std::stringstream msg; std::stringstream msg;
@ -261,8 +258,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirType type = torchMlirTorchBoolTypeGet(context); MlirType type = torchMlirTorchBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.bool", loc, type, importBlock, "torch.constant.bool", loc, type,
toMlirNamedAttribute( toMlirNamedAttribute("value",
"value", mlirBoolAttrGet(context, ivalue.toBool()))); mlirBoolAttrGet(context, ivalue.toBool())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isDouble()) { if (ivalue.isDouble()) {
@ -270,23 +267,23 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.float", loc, type, importBlock, "torch.constant.float", loc, type,
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirFloatAttrDoubleGet( "value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
context, mlirF64TypeGet(context), ivalue.toDouble()))); ivalue.toDouble())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isInt()) { if (ivalue.isInt()) {
MlirType type = torchMlirTorchIntTypeGet(context); MlirType type = torchMlirTorchIntTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.int", loc, type, importBlock, "torch.constant.int", loc, type,
toMlirNamedAttribute( toMlirNamedAttribute("value",
"value", mlirIntegerAttrGet( mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64),
mlirIntegerTypeGet(context, 64), ivalue.toInt()))); ivalue.toInt())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isList()) { if (ivalue.isList()) {
c10::List<c10::IValue> list = ivalue.toList(); c10::List<c10::IValue> list = ivalue.toList();
std::vector<MlirValue> elems; std::vector<MlirValue> elems;
for (const c10::IValue& elem : list) { for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem)); elems.push_back(importIValue(elem));
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
@ -316,7 +313,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
auto list = ivalue.toTuple()->elements(); auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> operands; std::vector<MlirValue> operands;
std::vector<MlirType> types; std::vector<MlirType> types;
for (const c10::IValue& elem : list) { for (const c10::IValue &elem : list) {
MlirValue operand = importIValue(elem); MlirValue operand = importIValue(elem);
operands.push_back(operand); operands.push_back(operand);
types.push_back(mlirValueGetType(operand)); types.push_back(mlirValueGetType(operand));
@ -339,14 +336,14 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
torchMlirTorchStringTypeGet(context), torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirStringAttrGet( mlirStringAttrGet(context,
context, toMlirStringRef(ivalue.toString()->string())))); toMlirStringRef(ivalue.toString()->string()))));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isNone()) { if (ivalue.isNone()) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation =
importBlock, "torch.constant.none", loc, createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
torchMlirTorchNoneTypeGet(context)); torchMlirTorchNoneTypeGet(context));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isCustomClass()) { if (ivalue.isCustomClass()) {
@ -440,12 +437,12 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
return tensorValue; return tensorValue;
} }
void IValueImporter::importMethod( void IValueImporter::importMethod(torch::jit::Function *function,
torch::jit::Function* function, MlirBlock classTypeBody, MlirBlock classTypeBody,
const MethodAnnotation& methodAnnotation) { const MethodAnnotation &methodAnnotation) {
// The function's name becomes the MLIR symbol table name of the imported func // The function's name becomes the MLIR symbol table name of the imported func
// when we import the compilation unit. // when we import the compilation unit.
const std::string& symName = function->qualname().qualifiedName(); const std::string &symName = function->qualname().qualifiedName();
MlirAttribute functionSymbolRef = MlirAttribute functionSymbolRef =
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)); mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
@ -461,7 +458,7 @@ void IValueImporter::importMethod(
toMlirNamedAttribute("function", functionSymbolRef), isPrivate); toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
} }
void IValueImporter::importClassType(c10::ClassType* classType) { void IValueImporter::importClassType(c10::ClassType *classType) {
if (!classTypes.insert(classType).second) { if (!classTypes.insert(classType).second) {
return; return;
} }
@ -479,13 +476,13 @@ void IValueImporter::importClassType(c10::ClassType* classType) {
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr)); mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region); MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
ClassAnnotation& classAnnotation = ClassAnnotation &classAnnotation =
annotator.getOrCreateClassAnnotation(classType); annotator.getOrCreateClassAnnotation(classType);
const auto& attributeAnnotations = classAnnotation.getAttributeAnnotations(); const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
const auto& classAttributes = classType->getAttributes(); const auto &classAttributes = classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
const c10::ClassAttribute& classAttribute = classAttributes[i]; const c10::ClassAttribute &classAttribute = classAttributes[i];
c10::optional<MlirNamedAttribute> isPrivate; c10::optional<MlirNamedAttribute> isPrivate;
if (!attributeAnnotations[i].isExported) { if (!attributeAnnotations[i].isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context)); isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
@ -501,8 +498,8 @@ void IValueImporter::importClassType(c10::ClassType* classType) {
isPrivate); isPrivate);
} }
const auto& methodAnnotations = classAnnotation.getMethodAnnotations(); const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
const auto& methods = classType->methods(); const auto &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
importMethod(methods[i], classTypeBody, methodAnnotations[i]); importMethod(methods[i], classTypeBody, methodAnnotations[i]);
} }
@ -510,7 +507,7 @@ void IValueImporter::importClassType(c10::ClassType* classType) {
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc); createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
} }
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) { void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
if (compilationUnit == nullptr) { if (compilationUnit == nullptr) {
compilationUnit = cu; compilationUnit = cu;
} else { } else {
@ -529,14 +526,14 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
return; return;
} }
for (torch::jit::Function* function : cu->get_functions()) { for (torch::jit::Function *function : cu->get_functions()) {
// Useful for debugging errors in free functions that end up being // Useful for debugging errors in free functions that end up being
// unused. These can be missing when round-tripping through the on-disk // unused. These can be missing when round-tripping through the on-disk
// format, even though they still cause import issues when importing // format, even though they still cause import issues when importing
// through the larger Python session where they originate. // through the larger Python session where they originate.
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n"; // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
// std::cerr << *torch::jit::toGraphFunction(function).graph(); // std::cerr << *torch::jit::toGraphFunction(function).graph();
MethodAnnotation* annotation = MethodAnnotation *annotation =
annotator.getMethodAnnotationForFunction(function); annotator.getMethodAnnotationForFunction(function);
MlirOperation func = importJitFunctionAsFuncOp( MlirOperation func = importJitFunctionAsFuncOp(
context, function, context, function,
@ -544,9 +541,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
if (!annotation || !annotation->argAnnotations.has_value()) { if (!annotation || !annotation->argAnnotations.has_value()) {
return {nullptr}; return {nullptr};
} }
c10::optional<std::vector<int64_t>>& maybeShape = c10::optional<std::vector<int64_t>> &maybeShape =
annotation->argAnnotations.value()[argIndex].shape; annotation->argAnnotations.value()[argIndex].shape;
c10::optional<c10::ScalarType>& maybeDtype = c10::optional<c10::ScalarType> &maybeDtype =
annotation->argAnnotations.value()[argIndex].dtype; annotation->argAnnotations.value()[argIndex].dtype;
bool hasValueSemantics = bool hasValueSemantics =
annotation->argAnnotations.value()[argIndex].hasValueSemantics; annotation->argAnnotations.value()[argIndex].hasValueSemantics;
@ -566,10 +563,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
// the C API constructor, when we want the "we know we have 0 sizes" // the C API constructor, when we want the "we know we have 0 sizes"
// case. So use a dummy data pointer. // case. So use a dummy data pointer.
int64_t dummy; int64_t dummy;
int64_t* shapeData = shape.size() == 0 ? &dummy : shape.data(); int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
if (hasValueSemantics) { if (hasValueSemantics) {
typeBound = torchMlirTorchValueTensorTypeGet( typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
context, shape.size(), shapeData, dtype); shapeData, dtype);
} else { } else {
typeBound = torchMlirTorchNonValueTensorTypeGet( typeBound = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shapeData, dtype); context, shape.size(), shapeData, dtype);
@ -597,9 +594,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
} }
} }
MlirValue torch_mlir::importIValue( MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
c10::IValue ivalue, MlirBlock block, MlirContext context, MlirContext context,
ClassAnnotator& annotator, const ImportOptions& importOptions) { ClassAnnotator &annotator,
const ImportOptions &importOptions) {
// When debugging module importing, it can be useful to dump as so: // When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule()) // if (ivalue.isModule())
// ivalue.toModule().dump(true, false, false); // ivalue.toModule().dump(true, false, false);

View File

@ -25,9 +25,9 @@ namespace torch_mlir {
/// Main entry-point for importing torch IValue's . /// Main entry-point for importing torch IValue's .
/// Recursively imports `ivalue`, inserting operations at the end of `block`. /// Recursively imports `ivalue`, inserting operations at the end of `block`.
MlirValue importIValue( MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
c10::IValue ivalue, MlirBlock block, MlirContext context, ClassAnnotator &annotator,
ClassAnnotator& annotator, const ImportOptions& importOptions); const ImportOptions &importOptions);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -22,92 +22,92 @@
namespace torch_mlir { namespace torch_mlir {
inline MlirStringRef toMlirStringRef(const std::string& s) { inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size()); return mlirStringRefCreate(s.data(), s.size());
} }
inline MlirStringRef toMlirStringRef(const char* s) { inline MlirStringRef toMlirStringRef(const char *s) {
return mlirStringRefCreate(s, std::strlen(s)); return mlirStringRefCreate(s, std::strlen(s));
} }
inline MlirNamedAttribute inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
toMlirNamedAttribute(const char* s, MlirAttribute attr) { MlirAttribute attr) {
MlirContext context = mlirAttributeGetContext(attr); MlirContext context = mlirAttributeGetContext(attr);
MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
return mlirNamedAttributeGet(ident, attr); return mlirNamedAttributeGet(ident, attr);
} }
inline void addToMlirOperationState( inline void addToMlirOperationState(MlirOperationState &state,
MlirOperationState& state, MlirNamedAttribute namedAttr) { MlirNamedAttribute namedAttr) {
mlirOperationStateAddAttributes(&state, 1, &namedAttr); mlirOperationStateAddAttributes(&state, 1, &namedAttr);
} }
inline void inline void addToMlirOperationState(MlirOperationState &state,
addToMlirOperationState(MlirOperationState& state, MlirRegion region) { MlirRegion region) {
mlirOperationStateAddOwnedRegions(&state, 1, &region); mlirOperationStateAddOwnedRegions(&state, 1, &region);
} }
inline void inline void addToMlirOperationState(MlirOperationState &state,
addToMlirOperationState(MlirOperationState& state, MlirValue value) { MlirValue value) {
mlirOperationStateAddOperands(&state, 1, &value); mlirOperationStateAddOperands(&state, 1, &value);
} }
inline void addToMlirOperationState( inline void addToMlirOperationState(MlirOperationState &state,
MlirOperationState& state, const std::vector<MlirValue>& values) { const std::vector<MlirValue> &values) {
mlirOperationStateAddOperands(&state, values.size(), values.data()); mlirOperationStateAddOperands(&state, values.size(), values.data());
} }
inline void addToMlirOperationState( inline void addToMlirOperationState(MlirOperationState &state,
MlirOperationState& state, c10::ArrayRef<MlirValue> values) { c10::ArrayRef<MlirValue> values) {
mlirOperationStateAddOperands(&state, values.size(), values.data()); mlirOperationStateAddOperands(&state, values.size(), values.data());
} }
inline void inline void addToMlirOperationState(MlirOperationState &state,
addToMlirOperationState(MlirOperationState& state, MlirType resultType) { MlirType resultType) {
mlirOperationStateAddResults(&state, 1, &resultType); mlirOperationStateAddResults(&state, 1, &resultType);
} }
inline void addToMlirOperationState( inline void addToMlirOperationState(MlirOperationState &state,
MlirOperationState& state, const std::vector<MlirType>& resultTypes) { const std::vector<MlirType> &resultTypes) {
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
} }
inline void addToMlirOperationState( inline void addToMlirOperationState(MlirOperationState &state,
MlirOperationState& state, c10::ArrayRef<MlirType> resultTypes) { c10::ArrayRef<MlirType> resultTypes) {
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
} }
template <typename T> template <typename T>
void addToMlirOperationState(MlirOperationState& state, c10::optional<T> o) { void addToMlirOperationState(MlirOperationState &state, c10::optional<T> o) {
if (o.has_value()) { if (o.has_value()) {
addToMlirOperationState(state, o.value()); addToMlirOperationState(state, o.value());
} }
} }
inline void addToMlirOperationState(MlirOperationState& state) {} inline void addToMlirOperationState(MlirOperationState &state) {}
template <typename T, typename U, typename... Ts> template <typename T, typename U, typename... Ts>
void addToMlirOperationState( void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u,
MlirOperationState& state, T&& t, U&& u, Ts&&... ts) { Ts &&...ts) {
addToMlirOperationState(state, std::forward<T>(t)); addToMlirOperationState(state, std::forward<T>(t));
addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...); addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
} }
template <typename... Ts> template <typename... Ts>
MlirOperation MlirOperation createMlirOperation(std::string name, MlirLocation loc,
createMlirOperation(std::string name, MlirLocation loc, Ts&&... ts) { Ts &&...ts) {
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
addToMlirOperationState(state, std::forward<Ts>(ts)...); addToMlirOperationState(state, std::forward<Ts>(ts)...);
return mlirOperationCreate(&state); return mlirOperationCreate(&state);
} }
template <typename... Ts> template <typename... Ts>
MlirOperation createMlirOperationAtEnd( MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
MlirBlock block, std::string name, MlirLocation loc, Ts&&... ts) { MlirLocation loc, Ts &&...ts) {
MlirOperation operation = MlirOperation operation =
createMlirOperation(name, loc, std::forward<Ts>(ts)...); createMlirOperation(name, loc, std::forward<Ts>(ts)...);
mlirBlockInsertOwnedOperationBefore( mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
block, mlirBlockGetTerminator(block), operation); operation);
return operation; return operation;
} }

View File

@ -33,42 +33,40 @@ class NodeImporter {
public: public:
NodeImporter(MlirContext context) : context(context) {} NodeImporter(MlirContext context) : context(context) {}
void importNode( void importNode(Node *node, MlirBlock appendToBlock,
Node* node, MlirBlock appendToBlock, const ImportOptions &importOptions = {});
const ImportOptions& importOptions = {});
MlirBlock importBlock( MlirBlock importBlock(
Block* jitBlock, CreateTerminatorFn createTerminator, Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
private: private:
MlirBlock createBlockFor( MlirBlock createBlockFor(Block *jitBlock,
Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
void mapValue(Value* jitValue, MlirValue value); void mapValue(Value *jitValue, MlirValue value);
void mapResults(Node* node, MlirOperation operation); void mapResults(Node *node, MlirOperation operation);
MlirValue lookupMappedValue(Value* jitValue); MlirValue lookupMappedValue(Value *jitValue);
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value*> values); std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);
MlirContext context; MlirContext context;
std::unordered_map<Value*, MlirValue> valueMap; std::unordered_map<Value *, MlirValue> valueMap;
}; };
} // namespace } // namespace
using InputsTransformFn = using InputsTransformFn =
std::function<std::vector<MlirValue>(std::vector<MlirValue>&)>; std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>;
// The inputs of `DictConstruct` in TorchScript IR are in the order // The inputs of `DictConstruct` in TorchScript IR are in the order
// like k0, v0, k1, v1. Rearrange them to put the key operands together and // like k0, v0, k1, v1. Rearrange them to put the key operands together and
// then the value operands like k0, k1,v0, v1. This is the expected format by // then the value operands like k0, k1,v0, v1. This is the expected format by
// the corresponding MLIR op. // the corresponding MLIR op.
static std::vector<MlirValue> static std::vector<MlirValue>
rearrangeDictConstructInputs(std::vector<MlirValue>& inputs) { rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
if (inputs.empty()) if (inputs.empty())
return inputs; return inputs;
assert( assert(inputs.size() % 2 == 0 &&
inputs.size() % 2 == 0 && "DictConstruct must have even number of operands");
"DictConstruct must have even number of operands");
std::vector<MlirValue> rearranged; std::vector<MlirValue> rearranged;
std::vector<MlirValue> values; std::vector<MlirValue> values;
@ -80,12 +78,12 @@ rearrangeDictConstructInputs(std::vector<MlirValue>& inputs) {
return rearranged; return rearranged;
} }
void NodeImporter::importNode( void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) { const ImportOptions &importOptions) {
MlirLocation loc = getMlirLocationFromNode(context, node); MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind(); auto kind = node->kind();
auto createAndMapTrivialNode = [&](Node* node, const std::string& opName, auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
InputsTransformFn t) { InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs()); std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
@ -96,7 +94,7 @@ void NodeImporter::importNode(
}; };
auto createAndMapNodeWithAttribute = auto createAndMapNodeWithAttribute =
[&](Node* node, const std::string& opName, const std::string& attrName, [&](Node *node, const std::string &opName, const std::string &attrName,
MlirAttribute attr) { MlirAttribute attr) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc, appendToBlock, opName, loc,
@ -133,27 +131,27 @@ void NodeImporter::importNode(
// ListConstruct and DictConstruct too. // ListConstruct and DictConstruct too.
auto containedTypes = c10::fmap( auto containedTypes = c10::fmap(
node->output()->type()->cast<c10::TupleType>()->containedTypes(), node->output()->type()->cast<c10::TupleType>()->containedTypes(),
[&](const c10::TypePtr& t) { [&](const c10::TypePtr &t) {
MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
return type; return type;
}); });
createAndMapTrivialNode( createAndMapTrivialNode(node,
node, "torch.prim." + std::string(kind.toUnqualString()), "torch.prim." + std::string(kind.toUnqualString()),
[&](std::vector<MlirValue>& inputs) { [&](std::vector<MlirValue> &inputs) {
assert(containedTypes.size() == inputs.size()); assert(containedTypes.size() == inputs.size());
return adjustStaticInformationForValues( return adjustStaticInformationForValues(
appendToBlock, loc, inputs, containedTypes, appendToBlock, loc, inputs, containedTypes,
/*userAllowsRefinement=*/true); /*userAllowsRefinement=*/true);
}); });
return; return;
} }
case c10::prim::DictConstruct: { case c10::prim::DictConstruct: {
createAndMapTrivialNode( createAndMapTrivialNode(node,
node, "torch.prim." + std::string(kind.toUnqualString()), "torch.prim." + std::string(kind.toUnqualString()),
rearrangeDictConstructInputs); rearrangeDictConstructInputs);
return; return;
} }
case c10::prim::Load: case c10::prim::Load:
@ -171,34 +169,32 @@ void NodeImporter::importNode(
auto output = node->output(); auto output = node->output();
MlirOperation op; MlirOperation op;
if (output->type()->cast<c10::NoneType>()) { if (output->type()->cast<c10::NoneType>()) {
op = createMlirOperation( op = createMlirOperation("torch.constant.none", loc,
"torch.constant.none", loc, torchMlirTorchNoneTypeGet(context)); torchMlirTorchNoneTypeGet(context));
} else if (output->type()->cast<c10::BoolType>()) { } else if (output->type()->cast<c10::BoolType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context), "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
mlirBoolAttrGet( c10::attr::value)))));
context, static_cast<bool>(node->i(c10::attr::value)))));
} else if (output->type()->cast<c10::IntType>()) { } else if (output->type()->cast<c10::IntType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.int", loc, "torch.constant.int", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute("value",
"value", importAttribute(loc, node, c10::attr::value))); importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::FloatType>()) { } else if (output->type()->cast<c10::FloatType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.float", loc, "torch.constant.float", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute("value",
"value", importAttribute(loc, node, c10::attr::value))); importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) { } else if (output->type()->cast<c10::StringType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context), "torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
mlirStringAttrGet( c10::attr::value)))));
context, toMlirStringRef(node->s(c10::attr::value)))));
} else if (output->type()->cast<c10::TensorType>()) { } else if (output->type()->cast<c10::TensorType>()) {
MlirAttribute attr = importAttribute(loc, node, c10::attr::value); MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
if (importOptions.assumeTensorsHaveValueSemantics) { if (importOptions.assumeTensorsHaveValueSemantics) {
@ -217,26 +213,24 @@ void NodeImporter::importNode(
"torch.constant.device", loc, "torch.constant.device", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
mlirStringAttrGet( c10::attr::value)))));
context, toMlirStringRef(node->s(c10::attr::value)))));
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) { } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
torch::jit::Function* function = functionType->function(); torch::jit::Function *function = functionType->function();
const std::string& symName = function->qualname().qualifiedName(); const std::string &symName = function->qualname().qualifiedName();
op = createMlirOperation( op = createMlirOperation(
"func.constant", loc, "func.constant", loc,
getFunctionTypeFromSchema( getFunctionTypeFromSchema(context, function->getSchema(),
context, function->getSchema(), importOptions), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
} else if ( } else if (output->type()->cast<c10::ListType>() ||
output->type()->cast<c10::ListType>() || output->type()->cast<c10::TupleType>()) {
output->type()->cast<c10::TupleType>()) {
ClassAnnotator dummyAnnotator; ClassAnnotator dummyAnnotator;
MlirValue listOrTupleValue = importIValue( MlirValue listOrTupleValue =
node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator, importIValue(node->ival(c10::attr::value), appendToBlock, context,
importOptions); dummyAnnotator, importOptions);
mapResults(node, mlirOpResultGetOwner(listOrTupleValue)); mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
return; // Early return, since `importIValue` already added op to block. return; // Early return, since `importIValue` already added op to block.
} else { } else {
@ -264,20 +258,19 @@ void NodeImporter::importNode(
mapResults(node, operation); mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = { std::vector<MlirType> terminatorOperandTypes = {
torchMlirTorchBoolTypeGet(context)}; torchMlirTorchBoolTypeGet(context)};
terminatorOperandTypes.insert( terminatorOperandTypes.insert(terminatorOperandTypes.end(),
terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end()); resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues, auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd( createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop.condition", loc, appendToBlock, "torch.prim.Loop.condition", loc,
adjustStaticInformationForValues( adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
appendToBlock, loc, yieldedValues, terminatorOperandTypes, terminatorOperandTypes,
/*userAllowsRefinement=*/false)); /*userAllowsRefinement=*/false));
}; };
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
mlirOperationGetRegion(operation, 0), importBlock(node->blocks()[0], createTerminator,
importBlock( c10::nullopt, importOptions));
node->blocks()[0], createTerminator, c10::nullopt, importOptions));
return; return;
} }
@ -292,25 +285,23 @@ void NodeImporter::importNode(
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd( createMlirOperationAtEnd(
appendToBlock, "torch.prim.If.yield", loc, appendToBlock, "torch.prim.If.yield", loc,
adjustStaticInformationForValues( adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
appendToBlock, loc, yieldedValues, resultTypes, resultTypes,
/*userAllowsRefinement=*/false)); /*userAllowsRefinement=*/false));
}; };
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
mlirOperationGetRegion(operation, 0), importBlock(node->blocks()[0], createTerminator,
importBlock( c10::nullopt, importOptions));
node->blocks()[0], createTerminator, c10::nullopt, importOptions)); mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
mlirRegionAppendOwnedBlock( importBlock(node->blocks()[1], createTerminator,
mlirOperationGetRegion(operation, 1), c10::nullopt, importOptions));
importBlock(
node->blocks()[1], createTerminator, c10::nullopt, importOptions));
return; return;
} }
if (kind == c10::prim::CallMethod) { if (kind == c10::prim::CallMethod) {
auto classType = node->input(0)->type()->cast<c10::ClassType>(); auto classType = node->input(0)->type()->cast<c10::ClassType>();
auto methodName = node->s(c10::attr::name); auto methodName = node->s(c10::attr::name);
torch::jit::Function* function = classType->findMethod(methodName); torch::jit::Function *function = classType->findMethod(methodName);
MlirType calleeType = getFunctionTypeFromSchema( MlirType calleeType = getFunctionTypeFromSchema(
context, function->getSchema(), importOptions); context, function->getSchema(), importOptions);
std::vector<MlirType> expectedTypes; std::vector<MlirType> expectedTypes;
@ -323,17 +314,17 @@ void NodeImporter::importNode(
adjustStaticInformationForValues( adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs()), appendToBlock, loc, lookupMappedValues(node->inputs()),
expectedTypes, /*userAllowsRefinement=*/false), expectedTypes, /*userAllowsRefinement=*/false),
toMlirNamedAttribute( toMlirNamedAttribute("name",
"name", importAttribute(loc, node, c10::attr::name))); importAttribute(loc, node, c10::attr::name)));
mapResults(node, operation); mapResults(node, operation);
return; return;
} }
if (kind == c10::prim::CallFunction) { if (kind == c10::prim::CallFunction) {
auto functionType = node->input(0)->type()->cast<c10::FunctionType>(); auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
torch::jit::Block* calleeEntryBlock = torch::jit::Block *calleeEntryBlock =
torch::jit::toGraphFunction(*functionType->function()).graph()->block(); torch::jit::toGraphFunction(*functionType->function()).graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value* v) { auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return getMlirTypeFromTorchType(loc, v->type(), importOptions); return getMlirTypeFromTorchType(loc, v->type(), importOptions);
}); });
std::string functionName = node->input(0)->node()->s(c10::attr::name); std::string functionName = node->input(0)->node()->s(c10::attr::name);
@ -348,9 +339,9 @@ void NodeImporter::importNode(
// promoted result dtype for a PyTorch computation. Here we turn the call to // promoted result dtype for a PyTorch computation. Here we turn the call to
// this function to the torch dialect equivalent op `torch.promote_dtypes`. // this function to the torch dialect equivalent op `torch.promote_dtypes`.
if (functionName == "__torch_mlir_internal_promote_dtypes") { if (functionName == "__torch_mlir_internal_promote_dtypes") {
operation = createMlirOperationAtEnd( operation =
appendToBlock, "torch.promote_dtypes", loc, resultTypes, createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc,
adjustedFuncArgs); resultTypes, adjustedFuncArgs);
} else { } else {
operation = createMlirOperationAtEnd( operation = createMlirOperationAtEnd(
appendToBlock, "func.call_indirect", loc, resultTypes, appendToBlock, "func.call_indirect", loc, resultTypes,
@ -369,23 +360,23 @@ void NodeImporter::importNode(
} }
} }
MlirBlock NodeImporter::importBlock( MlirBlock
Block* jitBlock, CreateTerminatorFn createTerminator, NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions& importOptions) { const ImportOptions &importOptions) {
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions); MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
for (Node* node : jitBlock->nodes()) { for (Node *node : jitBlock->nodes()) {
importNode(node, block, importOptions); importNode(node, block, importOptions);
} }
Node* returnNode = jitBlock->return_node(); Node *returnNode = jitBlock->return_node();
createTerminator(lookupMappedValues(returnNode->inputs()), block); createTerminator(lookupMappedValues(returnNode->inputs()), block);
return block; return block;
} }
MlirBlock NodeImporter::createBlockFor( MlirBlock NodeImporter::createBlockFor(
Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions& importOptions) { const ImportOptions &importOptions) {
Node* paramNode = jitBlock->param_node(); Node *paramNode = jitBlock->param_node();
MlirLocation loc = getMlirLocationFromNode(context, paramNode); MlirLocation loc = getMlirLocationFromNode(context, paramNode);
std::vector<MlirType> paramNodeTypes = std::vector<MlirType> paramNodeTypes =
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions); getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
@ -394,11 +385,11 @@ MlirBlock NodeImporter::createBlockFor(
else else
assert(blockArgTypes->size() == paramNodeTypes.size()); assert(blockArgTypes->size() == paramNodeTypes.size());
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc); std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
MlirBlock block = mlirBlockCreate( MlirBlock block =
blockArgTypes.value().size(), blockArgTypes.value().data(), mlirBlockCreate(blockArgTypes.value().size(),
blockArgLocs.data()); blockArgTypes.value().data(), blockArgLocs.data());
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) { for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
Value* jitValue = paramNode->outputs()[i]; Value *jitValue = paramNode->outputs()[i];
MlirValue value = mlirBlockGetArgument(block, i); MlirValue value = mlirBlockGetArgument(block, i);
MlirValue adjusted = adjustStaticInformationForValues( MlirValue adjusted = adjustStaticInformationForValues(
block, loc, {value}, {paramNodeTypes[i]}, block, loc, {value}, {paramNodeTypes[i]},
@ -408,40 +399,40 @@ MlirBlock NodeImporter::createBlockFor(
return block; return block;
} }
void NodeImporter::mapValue(Value* jitValue, MlirValue value) { void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
auto it = valueMap.find(jitValue); auto it = valueMap.find(jitValue);
(void)it; (void)it;
assert(it == valueMap.end() && "jitValue has already been mapped"); assert(it == valueMap.end() && "jitValue has already been mapped");
valueMap[jitValue] = value; valueMap[jitValue] = value;
} }
void NodeImporter::mapResults(Node* node, MlirOperation operation) { void NodeImporter::mapResults(Node *node, MlirOperation operation) {
assert( assert(node->outputs().size() ==
node->outputs().size() == (size_t)mlirOperationGetNumResults(operation)); (size_t)mlirOperationGetNumResults(operation));
for (int i = 0, e = node->outputs().size(); i < e; i++) { for (int i = 0, e = node->outputs().size(); i < e; i++) {
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i)); mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
} }
} }
MlirValue NodeImporter::lookupMappedValue(Value* jitValue) { MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
auto it = valueMap.find(jitValue); auto it = valueMap.find(jitValue);
assert( assert(it != valueMap.end() &&
it != valueMap.end() && "trying to get mapping for jitValue that is not mapped yet!");
"trying to get mapping for jitValue that is not mapped yet!");
return it->second; return it->second;
} }
std::vector<MlirValue> std::vector<MlirValue>
NodeImporter::lookupMappedValues(c10::ArrayRef<Value*> values) { NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
std::vector<MlirValue> ret; std::vector<MlirValue> ret;
for (Value* value : values) { for (Value *value : values) {
ret.push_back(lookupMappedValue(value)); ret.push_back(lookupMappedValue(value));
} }
return ret; return ret;
} }
MlirBlock torch_mlir::importBlock( MlirBlock
MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator, torch_mlir::importBlock(MlirContext context, Block *jitBlock,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, CreateTerminatorFn createTerminator,
const ImportOptions& importOptions) { c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions) {
NodeImporter importer(context); NodeImporter importer(context);
return importer.importBlock( return importer.importBlock(jitBlock, createTerminator, blockArgTypes,
jitBlock, createTerminator, blockArgTypes, importOptions); importOptions);
} }

View File

@ -36,11 +36,11 @@ using CreateTerminatorFn =
/// are required to be for correctness. The code will internally attempt to /// are required to be for correctness. The code will internally attempt to
/// adjust the types to the block argument types. /// adjust the types to the block argument types.
/// TODO: Formalize what type conversions are allowed here. /// TODO: Formalize what type conversions are allowed here.
MlirBlock importBlock( MlirBlock
MlirContext context, torch::jit::Block* jitBlock, importBlock(MlirContext context, torch::jit::Block *jitBlock,
CreateTerminatorFn createTerminator, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -26,8 +26,8 @@
using namespace torch_mlir; using namespace torch_mlir;
static MlirType getMlirTypeForTorchScalarTypeRaw( static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
MlirContext context, c10::ScalarType scalarType) { c10::ScalarType scalarType) {
using c10::ScalarType; using c10::ScalarType;
switch (scalarType) { switch (scalarType) {
case ScalarType::Byte: case ScalarType::Byte:
@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(
} }
} }
MlirType torch_mlir::getMlirTypeForTorchScalarType( MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
MlirLocation loc, c10::ScalarType scalarType) { c10::ScalarType scalarType) {
auto type = auto type =
getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType); getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType(
// There is no generic way to import custom classes (or their types), so we // There is no generic way to import custom classes (or their types), so we
// have to name match them here (and the relevant code in the ivalue // have to name match them here (and the relevant code in the ivalue
// importer) and create special IR constructs for them. // importer) and create special IR constructs for them.
static MlirType mapCustomClassType( static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) { const c10::ClassTypePtr &classType) {
// If the type is unnamed, it cannot be a custom class. // If the type is unnamed, it cannot be a custom class.
if (!classType->name().has_value()) { if (!classType->name().has_value()) {
return {nullptr}; return {nullptr};
@ -126,9 +126,10 @@ static MlirType mapCustomClassType(
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
MlirType torch_mlir::getMlirTypeFromTorchType( MlirType
MlirLocation loc, const c10::TypePtr& torchType, torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
const ImportOptions& importOptions) { const c10::TypePtr &torchType,
const ImportOptions &importOptions) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
using c10::TypeKind; using c10::TypeKind;
auto kind = torchType->kind(); auto kind = torchType->kind();
@ -140,11 +141,10 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
: torchMlirTorchNonValueTensorTypeGet; : torchMlirTorchNonValueTensorTypeGet;
if (importOptions.ignoreExistingTensorShapesAndDtypes) { if (importOptions.ignoreExistingTensorShapesAndDtypes) {
return getMlirTensorType( return getMlirTensorType(context,
context, /*numSizes=*/-1,
/*numSizes=*/-1, /*optionalSizes=*/nullptr,
/*optionalSizes=*/nullptr, /*optionalDtype=*/{nullptr});
/*optionalDtype=*/{nullptr});
} }
// Element type. // Element type.
@ -156,18 +156,17 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
return {nullptr}; return {nullptr};
} }
// Sizes. // Sizes.
auto& sizes = tensorType->symbolic_sizes(); auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) { if (!sizes.rank()) {
// Unranked. // Unranked.
return getMlirTensorType( return getMlirTensorType(context,
context, /*numSizes=*/-1,
/*numSizes=*/-1, /*optionalSizes=*/nullptr,
/*optionalSizes=*/nullptr, /*optionalDtype=*/
/*optionalDtype=*/ elementType);
elementType);
} }
// Ranked with possibly dynamic dims. // Ranked with possibly dynamic dims.
auto& symbolicShape = tensorType->symbolic_sizes(); auto &symbolicShape = tensorType->symbolic_sizes();
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.resize(*sizes.rank()); dims.resize(*sizes.rank());
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
@ -180,12 +179,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
// the C API constructor, when we want the "we know we have 0 sizes" // the C API constructor, when we want the "we know we have 0 sizes"
// case. So use a dummy data pointer. // case. So use a dummy data pointer.
int64_t dummy; int64_t dummy;
int64_t* dimsData = dims.size() == 0 ? &dummy : dims.data(); int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
return getMlirTensorType( return getMlirTensorType(context, dims.size(),
context, dims.size(), /*optionalSizes=*/dimsData,
/*optionalSizes=*/dimsData, /*optionalDtype=*/
/*optionalDtype=*/ elementType);
elementType);
} }
case TypeKind::IntType: { case TypeKind::IntType: {
return torchMlirTorchIntTypeGet(context); return torchMlirTorchIntTypeGet(context);
@ -209,22 +207,22 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
} }
case TypeKind::TupleType: { case TypeKind::TupleType: {
std::vector<MlirType> containedTypes; std::vector<MlirType> containedTypes;
for (const c10::TypePtr& type : for (const c10::TypePtr &type :
torchType->cast<c10::TupleType>()->containedTypes()) { torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back( containedTypes.push_back(
getMlirTypeFromTorchType(loc, type, importOptions)); getMlirTypeFromTorchType(loc, type, importOptions));
} }
return torchMlirTorchTupleTypeGet( return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
context, containedTypes.size(), containedTypes.data()); containedTypes.data());
} }
case TypeKind::UnionType: { case TypeKind::UnionType: {
std::vector<MlirType> containedTypes; std::vector<MlirType> containedTypes;
for (const c10::TypePtr& type : for (const c10::TypePtr &type :
torchType->cast<c10::UnionType>()->containedTypes()) { torchType->cast<c10::UnionType>()->containedTypes()) {
containedTypes.push_back(getMlirTypeFromTorchType(loc, type)); containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
} }
return torchMlirTorchUnionTypeGet( return torchMlirTorchUnionTypeGet(context, containedTypes.size(),
context, containedTypes.size(), containedTypes.data()); containedTypes.data());
} }
case TypeKind::ListType: { case TypeKind::ListType: {
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
@ -244,7 +242,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
return torchMlirTorchAnyTypeGet(context); return torchMlirTorchAnyTypeGet(context);
} }
case TypeKind::ClassType: { case TypeKind::ClassType: {
const c10::ClassTypePtr& classType = torchType->cast<c10::ClassType>(); const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
MlirType customClassType = mapCustomClassType(context, loc, classType); MlirType customClassType = mapCustomClassType(context, loc, classType);
if (!mlirTypeIsNull(customClassType)) { if (!mlirTypeIsNull(customClassType)) {
return customClassType; return customClassType;
@ -268,11 +266,12 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
} }
} }
MlirType torch_mlir::getFunctionTypeFromSchema( MlirType
MlirContext context, const c10::FunctionSchema& schema, torch_mlir::getFunctionTypeFromSchema(MlirContext context,
const ImportOptions& importOptions) { const c10::FunctionSchema &schema,
const ImportOptions &importOptions) {
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
auto mapType = [&](const c10::TypePtr& torchType) { auto mapType = [&](const c10::TypePtr &torchType) {
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions); MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
std::stringstream msg; std::stringstream msg;
@ -284,20 +283,17 @@ MlirType torch_mlir::getFunctionTypeFromSchema(
}; };
std::vector<MlirType> inputTypes = std::vector<MlirType> inputTypes =
c10::fmap(schema.arguments(), [&](const c10::Argument& arg) { c10::fmap(schema.arguments(),
return mapType(arg.type()); [&](const c10::Argument &arg) { return mapType(arg.type()); });
});
std::vector<MlirType> outputTypes = std::vector<MlirType> outputTypes =
c10::fmap(schema.returns(), [&](const c10::Argument& arg) { c10::fmap(schema.returns(),
return mapType(arg.type()); [&](const c10::Argument &arg) { return mapType(arg.type()); });
}); return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
return mlirFunctionTypeGet( outputTypes.size(), outputTypes.data());
context, inputTypes.size(), inputTypes.data(), outputTypes.size(),
outputTypes.data());
} }
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr( MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
at::Tensor tensor, MlirLocation loc) { MlirLocation loc) {
using at::ScalarType; using at::ScalarType;
auto throwUnsupportedTensorError = [&]() { auto throwUnsupportedTensorError = [&]() {
@ -312,8 +308,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
// The flat number of bytes throws an exception for tensors that are not // The flat number of bytes throws an exception for tensors that are not
// dense and accessible as such. // dense and accessible as such.
at::checkLayout( at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided); c10::Layout::Strided);
// Construct the ShapedType. // Construct the ShapedType.
@ -338,47 +334,47 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
switch (tensor.scalar_type()) { switch (tensor.scalar_type()) {
case ScalarType::Int: case ScalarType::Int:
return mlirDenseElementsAttrInt32Get( return mlirDenseElementsAttrInt32Get(
shapedType, numElements, static_cast<const int32_t*>(tensorData)); shapedType, numElements, static_cast<const int32_t *>(tensorData));
break; break;
case ScalarType::Long: case ScalarType::Long:
return mlirDenseElementsAttrInt64Get( return mlirDenseElementsAttrInt64Get(
shapedType, numElements, static_cast<const int64_t*>(tensorData)); shapedType, numElements, static_cast<const int64_t *>(tensorData));
break; break;
case ScalarType::Float: case ScalarType::Float:
return mlirDenseElementsAttrFloatGet( return mlirDenseElementsAttrFloatGet(
shapedType, numElements, static_cast<const float*>(tensorData)); shapedType, numElements, static_cast<const float *>(tensorData));
break; break;
case ScalarType::Double: case ScalarType::Double:
return mlirDenseElementsAttrDoubleGet( return mlirDenseElementsAttrDoubleGet(
shapedType, numElements, static_cast<const double*>(tensorData)); shapedType, numElements, static_cast<const double *>(tensorData));
break; break;
case ScalarType::Bool: { case ScalarType::Bool: {
// TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed // TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed
// upstream to take in a `const bool *` rather than a `const int *` to avoid // upstream to take in a `const bool *` rather than a `const int *` to avoid
// the unnecessary copying into an array four times as large. // the unnecessary copying into an array four times as large.
const int8_t* elements = static_cast<const int8_t*>(tensorData); const int8_t *elements = static_cast<const int8_t *>(tensorData);
std::vector<int> tensorDataVector(elements, elements + numElements); std::vector<int> tensorDataVector(elements, elements + numElements);
return mlirDenseElementsAttrBoolGet( return mlirDenseElementsAttrBoolGet(shapedType, numElements,
shapedType, numElements, tensorDataVector.data()); tensorDataVector.data());
} break; } break;
case ScalarType::QInt8: case ScalarType::QInt8:
return mlirDenseElementsAttrInt8Get( return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t*>(tensorData)); shapedType, numElements, static_cast<const int8_t *>(tensorData));
case ScalarType::QUInt8: case ScalarType::QUInt8:
return mlirDenseElementsAttrUInt8Get( return mlirDenseElementsAttrUInt8Get(
shapedType, numElements, static_cast<const uint8_t*>(tensorData)); shapedType, numElements, static_cast<const uint8_t *>(tensorData));
case ScalarType::BFloat16: case ScalarType::BFloat16:
return mlirDenseElementsAttrBFloat16Get( return mlirDenseElementsAttrBFloat16Get(
shapedType, numElements, static_cast<const uint16_t*>(tensorData)); shapedType, numElements, static_cast<const uint16_t *>(tensorData));
case ScalarType::Half: case ScalarType::Half:
return mlirDenseElementsAttrFloat16Get( return mlirDenseElementsAttrFloat16Get(
shapedType, numElements, static_cast<const uint16_t*>(tensorData)); shapedType, numElements, static_cast<const uint16_t *>(tensorData));
case ScalarType::Byte: case ScalarType::Byte:
return mlirDenseElementsAttrUInt8Get( return mlirDenseElementsAttrUInt8Get(
shapedType, numElements, static_cast<const uint8_t*>(tensorData)); shapedType, numElements, static_cast<const uint8_t *>(tensorData));
case ScalarType::Char: case ScalarType::Char:
return mlirDenseElementsAttrInt8Get( return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t*>(tensorData)); shapedType, numElements, static_cast<const int8_t *>(tensorData));
default: default:
throwUnsupportedTensorError(); throwUnsupportedTensorError();
@ -386,8 +382,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
return {nullptr}; // Unreachable. return {nullptr}; // Unreachable.
} }
MlirAttribute torch_mlir::importAttribute( MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) { torch::jit::Node *node,
c10::Symbol symbol) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
auto kind = node->kindOf(symbol); auto kind = node->kindOf(symbol);
switch (kind) { switch (kind) {
@ -396,8 +393,8 @@ MlirAttribute torch_mlir::importAttribute(
// do that. // do that.
return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol)); return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol));
case torch::jit::AttributeKind::f: case torch::jit::AttributeKind::f:
return mlirFloatAttrDoubleGet( return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
context, mlirF64TypeGet(context), node->f(symbol)); node->f(symbol));
case torch::jit::AttributeKind::s: case torch::jit::AttributeKind::s:
return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol))); return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
case torch::jit::AttributeKind::t: case torch::jit::AttributeKind::t:
@ -411,23 +408,23 @@ MlirAttribute torch_mlir::importAttribute(
} }
} }
MlirLocation torch_mlir::getMlirLocationFromNode( MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
MlirContext context, torch::jit::Node* node) { torch::jit::Node *node) {
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
if (node->hasAttribute(c10::Symbol::attr("source_files"))) { if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
const auto& sourceFiles = node->ss(c10::Symbol::attr("source_files")); const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files"));
const auto& lineNumbers = node->is(c10::Symbol::attr("line_numbers")); const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
const auto& functions = node->ss(c10::Symbol::attr("functions")); const auto &functions = node->ss(c10::Symbol::attr("functions"));
// Chain a sequence of calls to construct single MlirLocation. // Chain a sequence of calls to construct single MlirLocation.
for (const auto i : c10::irange(sourceFiles.size())) { for (const auto i : c10::irange(sourceFiles.size())) {
MlirLocation newLoc = mlirLocationNameGet( MlirLocation newLoc = mlirLocationNameGet(
context, toMlirStringRef(functions[i]), context, toMlirStringRef(functions[i]),
mlirLocationFileLineColGet( mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]),
context, toMlirStringRef(sourceFiles[i]), lineNumbers[i], lineNumbers[i],
0 /* column is not available */ 0 /* column is not available */
)); ));
loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc)); loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
} }
if (sourceFiles.size() == 1) { if (sourceFiles.size() == 1) {
@ -436,7 +433,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(
loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context)); loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
} }
} else if (auto flc = node->sourceRange().file_line_col()) { } else if (auto flc = node->sourceRange().file_line_col()) {
const std::string& file = std::get<0>(*flc); const std::string &file = std::get<0>(*flc);
int line = std::get<1>(*flc); int line = std::get<1>(*flc);
int col = std::get<2>(*flc); int col = std::get<2>(*flc);
loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col); loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
@ -448,7 +445,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(
locationName = scopeName; locationName = scopeName;
} }
if (const c10::FunctionSchema* schema = node->maybeSchema()) { if (const c10::FunctionSchema *schema = node->maybeSchema()) {
if (!locationName.empty()) { if (!locationName.empty()) {
locationName += "/"; locationName += "/";
} }
@ -462,9 +459,10 @@ MlirLocation torch_mlir::getMlirLocationFromNode(
return loc; return loc;
} }
std::vector<MlirType> torch_mlir::getMlirTypesFromValues( std::vector<MlirType>
MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values, torch_mlir::getMlirTypesFromValues(MlirLocation loc,
const ImportOptions& importOptions) { c10::ArrayRef<torch::jit::Value *> values,
const ImportOptions &importOptions) {
std::vector<MlirType> ret; std::vector<MlirType> ret;
for (auto value : values) { for (auto value : values) {
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions); MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
@ -493,24 +491,25 @@ std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
} }
std::stringstream msg; std::stringstream msg;
MlirStringCallback printToStream = +[](MlirStringRef str, void* userData) { MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
std::stringstream* stream = static_cast<std::stringstream*>(userData); std::stringstream *stream = static_cast<std::stringstream *>(userData);
stream->write(str.data, str.length); stream->write(str.data, str.length);
}; };
msg << "unhandled: could not adjust static info for type from "; msg << "unhandled: could not adjust static info for type from ";
mlirTypePrint(type, printToStream, static_cast<void*>(&msg)); mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
msg << " to type "; msg << " to type ";
mlirTypePrint(expectedType, printToStream, static_cast<void*>(&msg)); mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
mlirEmitError(loc, msg.str().c_str()); mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
return ret; return ret;
} }
MlirOperation torch_mlir::createOperationFromSchema( MlirOperation
MlirBlock appendToBlock, MlirLocation loc, torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc,
const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes, const c10::FunctionSchema &schema,
c10::ArrayRef<MlirValue> operands) { c10::ArrayRef<MlirType> resultTypes,
c10::ArrayRef<MlirValue> operands) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
// Munge the name into the appropriate MLIR operation name. // Munge the name into the appropriate MLIR operation name.
@ -520,15 +519,15 @@ MlirOperation torch_mlir::createOperationFromSchema(
auto separatorPosition = opNameSuffix.find_first_of("::"); auto separatorPosition = opNameSuffix.find_first_of("::");
assert(separatorPosition != std::string::npos); assert(separatorPosition != std::string::npos);
opNameSuffix.replace(separatorPosition, 2, "."); opNameSuffix.replace(separatorPosition, 2, ".");
const std::string& overloadName = schema.overload_name(); const std::string &overloadName = schema.overload_name();
if (!overloadName.empty()) { if (!overloadName.empty()) {
opNameSuffix = opNameSuffix + "." + overloadName; opNameSuffix = opNameSuffix + "." + overloadName;
} }
std::string opName = "torch." + opNameSuffix; std::string opName = "torch." + opNameSuffix;
// If we have a registered op, use it! // If we have a registered op, use it!
if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) { if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) {
return createMlirOperationAtEnd( return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes,
appendToBlock, opName, loc, resultTypes, operands); operands);
} }
// Oops, no registered op -- create an opaque wrapper so that import can // Oops, no registered op -- create an opaque wrapper so that import can
// still succeed. This helps a common use case of filling out registered ops // still succeed. This helps a common use case of filling out registered ops

View File

@ -25,7 +25,7 @@ namespace torch_mlir {
/// Thrown on failure when details are in MLIR emitted diagnostics. /// Thrown on failure when details are in MLIR emitted diagnostics.
class mlir_diagnostic_emitted : public std::runtime_error { class mlir_diagnostic_emitted : public std::runtime_error {
public: public:
mlir_diagnostic_emitted(const char* what) : std::runtime_error(what) {} mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
}; };
@ -38,36 +38,37 @@ public:
/// for Python code). /// for Python code).
/// ///
/// Returns a null type on failure and emits a diagnostic. /// Returns a null type on failure and emits a diagnostic.
MlirType MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType); c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type /// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic. /// on failure and emits a diagnostic.
MlirType getMlirTypeFromTorchType( MlirType getMlirTypeFromTorchType(MlirLocation loc,
MlirLocation loc, const c10::TypePtr& torchType, const c10::TypePtr &torchType,
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
/// Creates a FunctionType suitable for expressing the signature of `schema`. /// Creates a FunctionType suitable for expressing the signature of `schema`.
/// ///
/// This can differ from the type inferred from the block of a /// This can differ from the type inferred from the block of a
/// torch::jit::Function due to derefinement and refinement of tensor types. /// torch::jit::Function due to derefinement and refinement of tensor types.
MlirType getFunctionTypeFromSchema( MlirType getFunctionTypeFromSchema(MlirContext context,
MlirContext context, const c10::FunctionSchema& schema, const c10::FunctionSchema &schema,
const ImportOptions& importOptions = {}); const ImportOptions &importOptions = {});
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`. /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc); MlirLocation loc);
MlirAttribute MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node,
importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol); c10::Symbol symbol);
MlirLocation MlirLocation getMlirLocationFromNode(MlirContext context,
getMlirLocationFromNode(MlirContext context, torch::jit::Node* node); torch::jit::Node *node);
std::vector<MlirType> getMlirTypesFromValues( std::vector<MlirType>
MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values, getMlirTypesFromValues(MlirLocation loc,
const ImportOptions& importOptions = {}); c10::ArrayRef<torch::jit::Value *> values,
const ImportOptions &importOptions = {});
std::vector<MlirValue> adjustStaticInformationForValues( std::vector<MlirValue> adjustStaticInformationForValues(
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values, MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
@ -78,10 +79,11 @@ std::vector<MlirValue> adjustStaticInformationForValues(
/// ///
/// The primary difficulty here is doing the appropriate name munging and /// The primary difficulty here is doing the appropriate name munging and
/// checking if the have a registered op. /// checking if the have a registered op.
MlirOperation createOperationFromSchema( MlirOperation createOperationFromSchema(MlirBlock appendToBlock,
MlirBlock appendToBlock, MlirLocation loc, MlirLocation loc,
const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes, const c10::FunctionSchema &schema,
c10::ArrayRef<MlirValue> operands); c10::ArrayRef<MlirType> resultTypes,
c10::ArrayRef<MlirValue> operands);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -1,30 +1,3 @@
# Static library with core functionality.
# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
add_library(TorchMLIRJITIRImporter STATIC
class_annotator.cpp
function_importer.cpp
node_importer.cpp
ivalue_importer.cpp
torch_to_mlir_utils.cpp
)
target_link_libraries(TorchMLIRJITIRImporter
TorchMLIRAggregateCAPI
${TORCH_LIBRARIES}
)
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
target_include_directories(TorchMLIRJITIRImporter PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/..
)
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
OUTPUT_NAME lib_jit_ir_importer
PREFIX ""
SUFFIX ".a"
CXX_VISIBILITY_PRESET "default"
COMPILE_FLAGS "${TORCH_CXXFLAGS}"
)
# Separate Pybind MODULE due to issues with a SHARED library. # Separate Pybind MODULE due to issues with a SHARED library.
# https://github.com/llvm/torch-mlir/issues/1154 # https://github.com/llvm/torch-mlir/issues/1154
add_library(TorchMLIRJITIRImporterPybind MODULE add_library(TorchMLIRJITIRImporterPybind MODULE

View File

@ -8,7 +8,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "class_annotator_pybind.h" #include "class_annotator_pybind.h"
#include "class_annotator.h" #include "jit_ir_importer/class_annotator.h"
#include <torch/csrc/Dtype.h> #include <torch/csrc/Dtype.h>
#include <torch/csrc/utils/pybind.h> #include <torch/csrc/utils/pybind.h>

View File

@ -8,7 +8,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "import_options_pybind.h" #include "import_options_pybind.h"
#include "import_options.h" #include "jit_ir_importer/import_options.h"
namespace py = pybind11; namespace py = pybind11;

View File

@ -9,9 +9,9 @@
#include "module_builder.h" #include "module_builder.h"
#include "function_importer.h" #include "jit_ir_importer/function_importer.h"
#include "ivalue_importer.h" #include "jit_ir_importer/ivalue_importer.h"
#include "mlir_utils.h" #include "jit_ir_importer/mlir_utils.h"
#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"

View File

@ -10,7 +10,7 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H #ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H #define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
#include "class_annotator.h" #include "jit_ir_importer/class_annotator.h"
#include "mlir-c/IR.h" #include "mlir-c/IR.h"