Get LTC building.

breakup_python_pytorch_deps
Stella Laurenzo 2023-11-18 17:56:00 -08:00
parent fabb4d6e5d
commit 606dc45896
31 changed files with 508 additions and 480 deletions

2
.gitignore vendored
View File

@ -26,7 +26,7 @@ __pycache__
bazel-* bazel-*
# Autogenerated files # Autogenerated files
/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated /projects/ltc/csrc/base_lazy_backend/generated
#Docker builds #Docker builds
build_oot/ build_oot/

View File

@ -29,7 +29,6 @@ if not TORCH_INCLUDE_DIR.is_dir():
TORCH_INCLUDE_DIR = TORCH_DIR TORCH_INCLUDE_DIR = TORCH_DIR
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1"
def reindent(text, prefix=""): def reindent(text, prefix=""):
return indent(dedent(text), prefix) return indent(dedent(text), prefix)
@ -114,12 +113,12 @@ class GenTorchMlirLTC:
self.binary_dir = Path(binary_dir) self.binary_dir = Path(binary_dir)
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}" assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml") self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
self.backend_path = TORCH_MLIR_PT1_DIR.joinpath( self.backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "base_lazy_backend" "projects", "ltc", "csrc", "base_lazy_backend"
) )
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}" assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
self.generated_path = self.binary_dir.joinpath( self.generated_path = self.binary_dir.joinpath(
"projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated" "projects", "ltc", "csrc", "base_lazy_backend", "generated"
) )
self.generated_path.mkdir(parents=True, exist_ok=True) self.generated_path.mkdir(parents=True, exist_ok=True)
@ -415,7 +414,7 @@ class GenTorchMlirLTC:
// for ops that don't have a corresponding structured kernel or shape definition // for ops that don't have a corresponding structured kernel or shape definition
#include "shape_inference.h" #include "shape_inference.h"
#include "torch_mlir/csrc/base_lazy_backend/utils/exception.h" #include "base_lazy_backend/utils/exception.h"
namespace torch {{ namespace torch {{
namespace lazy {{ namespace lazy {{
{} {}
@ -467,7 +466,7 @@ class GenTorchMlirLTC:
node_base="torch::lazy::TorchMlirNode", node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
tensor_class=self.tensor_class, tensor_class=self.tensor_class,
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", tensor_class_hdr="base_lazy_backend/tensor.h",
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
lazy_ir_generator=GenMlirLazyIr, lazy_ir_generator=GenMlirLazyIr,

View File

@ -12,7 +12,7 @@
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR. [Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation. After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
Implementations based on this abstract class will be able to specify their own compile and execution workflows. Implementations based on this abstract class will be able to specify their own compile and execution workflows.
Additional details about how to implement a custom backend are available [below](#Implementing-a-custom-backend). Additional details about how to implement a custom backend are available [below](#Implementing-a-custom-backend).
@ -27,7 +27,7 @@ View examples [here](ltc_examples.md).
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td), - The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
excluding those explicitly blacklisted in the YAML file excluding those explicitly blacklisted in the YAML file
### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated)) ### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated))
Generated files are created in this directory, which is ignored by version control. Generated files are created in this directory, which is ignored by version control.
- `LazyIr.h` - `LazyIr.h`
@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr
- `shape_inference.{cpp,h}` - `shape_inference.{cpp,h}`
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions - Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend)) ### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend))
- `backend_impl.{cpp,h}` - `backend_impl.{cpp,h}`
- Base LTC backend to setup Torch-MLIR lowering context - Base LTC backend to setup Torch-MLIR lowering context

View File

@ -56,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED
utils/tensor_utils.cpp utils/tensor_utils.cpp
) )
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/...").
# Add both the source and generated include directories.
target_include_directories(torch_mlir_ltc_backend PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_BINARY_DIR}/..
)
add_dependencies(torch_mlir_ltc_backend add_dependencies(torch_mlir_ltc_backend
TorchMLIRJITIRImporter TorchMLIRJITIRImporter
@ -88,13 +94,13 @@ add_custom_command(
add_custom_command( add_custom_command(
TARGET torch_mlir_ltc_backend POST_BUILD TARGET torch_mlir_ltc_backend POST_BUILD
COMMAND cp COMMAND cp
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/) ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
add_custom_command( add_custom_command(
TARGET torch_mlir_ltc_backend POST_BUILD TARGET torch_mlir_ltc_backend POST_BUILD
COMMAND cp COMMAND cp
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/) ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
add_custom_command( add_custom_command(
@ -105,7 +111,7 @@ add_custom_command(
add_custom_command( add_custom_command(
TARGET torch_mlir_ltc_backend POST_BUILD TARGET torch_mlir_ltc_backend POST_BUILD
COMMAND cp COMMAND cp
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/) ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
add_custom_command( add_custom_command(
@ -116,5 +122,5 @@ add_custom_command(
add_custom_command( add_custom_command(
TARGET torch_mlir_ltc_backend POST_BUILD TARGET torch_mlir_ltc_backend POST_BUILD
COMMAND cp COMMAND cp
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/) ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)

View File

@ -21,8 +21,8 @@
#include "mlir-c/IR.h" #include "mlir-c/IR.h"
#include "mlir-c/Pass.h" #include "mlir-c/Pass.h"
#include "../../jit_ir_importer/csrc/function_importer.h"
#include "backend_impl.h" #include "backend_impl.h"
#include "jit_ir_importer/function_importer.h"
#include "mlir_lowering_context.h" #include "mlir_lowering_context.h"
#include "mlir_node.h" #include "mlir_node.h"
#include "utils/debug.h" #include "utils/debug.h"

View File

@ -100,6 +100,7 @@ endif()
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
add_subdirectory(torch_mlir/jit_ir_importer) add_subdirectory(torch_mlir/jit_ir_importer)
add_subdirectory(torch_mlir/csrc/jit_ir_importer)
add_subdirectory(torch_mlir_e2e_test) add_subdirectory(torch_mlir_e2e_test)
endif() endif()

View File

@ -12,6 +12,10 @@ target_link_libraries(TorchMLIRJITIRImporter
TorchMLIRAggregateCAPI TorchMLIRAggregateCAPI
${TORCH_LIBRARIES} ${TORCH_LIBRARIES}
) )
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
target_include_directories(TorchMLIRJITIRImporter PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/..
)
set_target_properties(TorchMLIRJITIRImporter PROPERTIES set_target_properties(TorchMLIRJITIRImporter PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
OUTPUT_NAME lib_jit_ir_importer OUTPUT_NAME lib_jit_ir_importer

View File

@ -18,8 +18,8 @@ using namespace torch_mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Prefix every line of `s` with `linePrefix`. // Prefix every line of `s` with `linePrefix`.
static std::string indentString(const std::string &linePrefix, static std::string
const std::string &s) { indentString(const std::string& linePrefix, const std::string& s) {
std::stringstream is(s); std::stringstream is(s);
std::stringstream os; std::stringstream os;
std::string line; std::string line;
@ -39,26 +39,28 @@ ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType)
methodAnnotations.resize(classType->methods().size()); methodAnnotations.resize(classType->methods().size());
} }
std::vector<AttributeAnnotation> &ClassAnnotation::getAttributeAnnotations() { std::vector<AttributeAnnotation>& ClassAnnotation::getAttributeAnnotations() {
// Halfhearted attempt to ensure consistency if the class type has // Halfhearted attempt to ensure consistency if the class type has
// been mutated. // been mutated.
// //
// We can't easily guard against attributes being removed and // We can't easily guard against attributes being removed and
// then other attributes being added, or types changed, etc. without // then other attributes being added, or types changed, etc. without
// effectively mirroring the entire ClassType. // effectively mirroring the entire ClassType.
assert(attributeAnnotations.size() == classType->getAttributes().size() && assert(
"annotations out of sync. class has been mutated"); attributeAnnotations.size() == classType->getAttributes().size() &&
"annotations out of sync. class has been mutated");
return attributeAnnotations; return attributeAnnotations;
} }
std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() { std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() {
// Halfhearted attempt to ensure consistency if the class type has // Halfhearted attempt to ensure consistency if the class type has
// been mutated. // been mutated.
// //
// We can't easily guard against methods being removed, added, or changed. // We can't easily guard against methods being removed, added, or changed.
assert(methodAnnotations.size() == classType->methods().size() && assert(
"annotations out of sync. class has been mutated"); methodAnnotations.size() == classType->methods().size() &&
"annotations out of sync. class has been mutated");
return methodAnnotations; return methodAnnotations;
} }
@ -67,17 +69,17 @@ std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() {
// ClassAnnotator // ClassAnnotator
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void exportNoneRecurse(ClassAnnotator &classAnnotator, static void
c10::ClassType *classType) { exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) {
ClassAnnotation &classAnnotation = ClassAnnotation& classAnnotation =
classAnnotator.getOrCreateClassAnnotation(classType); classAnnotator.getOrCreateClassAnnotation(classType);
for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) { for (auto& attributeAnnotation : classAnnotation.getAttributeAnnotations()) {
attributeAnnotation.isExported = false; attributeAnnotation.isExported = false;
} }
for (auto &methodAnnotation : classAnnotation.getMethodAnnotations()) { for (auto& methodAnnotation : classAnnotation.getMethodAnnotations()) {
methodAnnotation.isExported = false; methodAnnotation.isExported = false;
} }
for (auto &classAttribute : classType->getAttributes()) { for (auto& classAttribute : classType->getAttributes()) {
if (auto childClassType = if (auto childClassType =
classAttribute.getType()->cast<c10::ClassType>()) { classAttribute.getType()->cast<c10::ClassType>()) {
exportNoneRecurse(classAnnotator, childClassType.get()); exportNoneRecurse(classAnnotator, childClassType.get());
@ -85,20 +87,20 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator,
} }
} }
void ClassAnnotator::exportNone(c10::ClassType &rootClassType) { void ClassAnnotator::exportNone(c10::ClassType& rootClassType) {
exportNoneRecurse(*this, &rootClassType); exportNoneRecurse(*this, &rootClassType);
} }
void ClassAnnotator::exportPath(c10::ClassType &rootClassType, void ClassAnnotator::exportPath(
std::vector<std::string> exportedPath) { c10::ClassType& rootClassType, std::vector<std::string> exportedPath) {
if (exportedPath.size() == 0) { if (exportedPath.size() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"Empty exported path. Can only export a property of a class."); "Empty exported path. Can only export a property of a class.");
} }
c10::ClassType *classType = c10::ClassType* classType = getClassAtPath(
getClassAtPath(&rootClassType, c10::ArrayRef<std::string>(exportedPath) &rootClassType, c10::ArrayRef<std::string>(exportedPath)
.slice(0, exportedPath.size() - 1) .slice(0, exportedPath.size() - 1)
.vec()); .vec());
if (!classType->findAttribute(exportedPath.back()) && if (!classType->findAttribute(exportedPath.back()) &&
!classType->findMethod(exportedPath.back())) { !classType->findMethod(exportedPath.back())) {
@ -108,10 +110,10 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
<< exportedPath.back() << "'"; << exportedPath.back() << "'";
throw std::invalid_argument(ss.str()); throw std::invalid_argument(ss.str());
} }
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType);
std::vector<AttributeAnnotation> &attributeAnnotations = std::vector<AttributeAnnotation>& attributeAnnotations =
classAnnotation.getAttributeAnnotations(); classAnnotation.getAttributeAnnotations();
const std::vector<c10::ClassAttribute> &classAttributes = const std::vector<c10::ClassAttribute>& classAttributes =
classType->getAttributes(); classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
if (classAttributes[i].getName() == exportedPath.back()) { if (classAttributes[i].getName() == exportedPath.back()) {
@ -119,9 +121,9 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
} }
} }
std::vector<MethodAnnotation> &methodAnnotations = std::vector<MethodAnnotation>& methodAnnotations =
classAnnotation.getMethodAnnotations(); classAnnotation.getMethodAnnotations();
const std::vector<torch::jit::Function *> &methods = classType->methods(); const std::vector<torch::jit::Function*>& methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
if (methods[i]->name() == exportedPath.back()) { if (methods[i]->name() == exportedPath.back()) {
methodAnnotations[i].isExported = true; methodAnnotations[i].isExported = true;
@ -129,12 +131,12 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
} }
} }
const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() { const ClassAnnotationMap& ClassAnnotator::getAnnotationMap() {
return classAnnotations; return classAnnotations;
} }
ClassAnnotation & ClassAnnotation&
ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) {
auto className = classType->name()->qualifiedName(); auto className = classType->name()->qualifiedName();
auto it = classAnnotations.find(className); auto it = classAnnotations.find(className);
if (it == classAnnotations.end()) { if (it == classAnnotations.end()) {
@ -149,39 +151,39 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
return *it->second; return *it->second;
} }
static void fillArgAnnotations(MethodAnnotation &methodAnnotation, static void fillArgAnnotations(
std::vector<ArgAnnotation> argAnnotations, MethodAnnotation& methodAnnotation,
torch::jit::Function *function) { std::vector<ArgAnnotation> argAnnotations, torch::jit::Function* function) {
if (argAnnotations.size() != function->num_inputs()) { if (argAnnotations.size() != function->num_inputs()) {
throw std::invalid_argument("Arg annotations should have one entry per " throw std::invalid_argument("Arg annotations should have one entry per "
"function parameter (including self)."); "function parameter (including self).");
} }
if (!methodAnnotation.argAnnotations.has_value()) { if (!methodAnnotation.argAnnotations.has_value()) {
methodAnnotation.argAnnotations.emplace(function->num_inputs(), methodAnnotation.argAnnotations.emplace(
ArgAnnotation{}); function->num_inputs(), ArgAnnotation{});
} }
methodAnnotation.argAnnotations = argAnnotations; methodAnnotation.argAnnotations = argAnnotations;
} }
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType, void ClassAnnotator::annotateArgs(
std::vector<std::string> path, c10::ClassType& rootClassType, std::vector<std::string> path,
std::vector<ArgAnnotation> argAnnotations) { std::vector<ArgAnnotation> argAnnotations) {
if (path.size() == 0) { if (path.size() == 0) {
throw std::invalid_argument("Empty annotated path. Can only annotate " throw std::invalid_argument("Empty annotated path. Can only annotate "
"shapes/dtypes of a method of a class."); "shapes/dtypes of a method of a class.");
} }
c10::ClassType *classType = getClassAtPath( c10::ClassType* classType = getClassAtPath(
&rootClassType, &rootClassType,
c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec()); c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec());
// Throw error if no method on the class of the specified name. // Throw error if no method on the class of the specified name.
torch::jit::Function *function = &classType->getMethod(path.back()); torch::jit::Function* function = &classType->getMethod(path.back());
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType); ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType);
std::vector<MethodAnnotation> &methodAnnotations = std::vector<MethodAnnotation>& methodAnnotations =
classAnnotation.getMethodAnnotations(); classAnnotation.getMethodAnnotations();
const std::vector<torch::jit::Function *> &methods = classType->methods(); const std::vector<torch::jit::Function*>& methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
if (methods[i]->name() == path.back()) { if (methods[i]->name() == path.back()) {
fillArgAnnotations(methodAnnotations[i], argAnnotations, function); fillArgAnnotations(methodAnnotations[i], argAnnotations, function);
@ -191,9 +193,9 @@ void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
return; return;
} }
c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType, c10::ClassType* ClassAnnotator::getClassAtPath(
std::vector<std::string> path) { c10::ClassType* rootClassType, std::vector<std::string> path) {
c10::ClassType *classType = rootClassType; c10::ClassType* classType = rootClassType;
// Reverse so that pop_back gives us the initial atoms first. // Reverse so that pop_back gives us the initial atoms first.
std::reverse(path.begin(), path.end()); std::reverse(path.begin(), path.end());
while (!path.empty()) { while (!path.empty()) {
@ -215,8 +217,8 @@ c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Helper methods // Helper methods
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
MethodAnnotation * MethodAnnotation*
ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) { ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) {
auto it = functionToMethodMap.find(function); auto it = functionToMethodMap.find(function);
if (it == functionToMethodMap.end()) { if (it == functionToMethodMap.end()) {
return nullptr; return nullptr;
@ -228,7 +230,7 @@ ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) {
// toString methods // toString methods
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
std::string AttributeAnnotation::toString(const std::string &name) { std::string AttributeAnnotation::toString(const std::string& name) {
std::stringstream ss; std::stringstream ss;
ss << "AttributeAnnotation('" << name << "') {\n"; ss << "AttributeAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n";
@ -259,7 +261,7 @@ std::string ArgAnnotation::toString(int argIndex) {
return ss.str(); return ss.str();
} }
std::string MethodAnnotation::toString(const std::string &name) { std::string MethodAnnotation::toString(const std::string& name) {
std::stringstream ss; std::stringstream ss;
ss << "MethodAnnotation('" << name << "') {\n"; ss << "MethodAnnotation('" << name << "') {\n";
ss << " isExported = " << (isExported ? "true" : "false") << "\n"; ss << " isExported = " << (isExported ? "true" : "false") << "\n";
@ -280,13 +282,13 @@ std::string ClassAnnotation::toString() {
std::stringstream ss; std::stringstream ss;
ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n"; ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n";
const std::vector<c10::ClassAttribute> &classAttributes = const std::vector<c10::ClassAttribute>& classAttributes =
classType->getAttributes(); classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
ss << indentString( ss << indentString(
" ", attributeAnnotations[i].toString(classAttributes[i].getName())); " ", attributeAnnotations[i].toString(classAttributes[i].getName()));
} }
const std::vector<torch::jit::Function *> &methods = classType->methods(); const std::vector<torch::jit::Function*>& methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name())); ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name()));
} }
@ -297,7 +299,7 @@ std::string ClassAnnotation::toString() {
std::string ClassAnnotator::toString() { std::string ClassAnnotator::toString() {
std::stringstream ss; std::stringstream ss;
ss << "ClassAnnotator {\n"; ss << "ClassAnnotator {\n";
for (auto &p : classAnnotations) { for (auto& p : classAnnotations) {
ss << indentString(" ", p.second->toString()); ss << indentString(" ", p.second->toString());
} }
ss << "}\n"; ss << "}\n";

View File

@ -34,7 +34,7 @@ struct AttributeAnnotation {
// can be externally accessed. // can be externally accessed.
bool isExported = true; bool isExported = true;
std::string toString(const std::string &name); std::string toString(const std::string& name);
}; };
// An annotation of an argument of a method. // An annotation of an argument of a method.
@ -80,7 +80,7 @@ struct MethodAnnotation {
// large printout of the default ArgAnnotation for every method. // large printout of the default ArgAnnotation for every method.
c10::optional<std::vector<ArgAnnotation>> argAnnotations; c10::optional<std::vector<ArgAnnotation>> argAnnotations;
std::string toString(const std::string &name); std::string toString(const std::string& name);
}; };
// Annotations on a c10::ClassType. // Annotations on a c10::ClassType.
@ -107,10 +107,10 @@ public:
// Get the attribute annotations. // Get the attribute annotations.
// The length and order is the same as `classType->getAttributes()`. // The length and order is the same as `classType->getAttributes()`.
std::vector<AttributeAnnotation> &getAttributeAnnotations(); std::vector<AttributeAnnotation>& getAttributeAnnotations();
// Get the method annotations. // Get the method annotations.
// The length and order is the same as `classType->methods()`. // The length and order is the same as `classType->methods()`.
std::vector<MethodAnnotation> &getMethodAnnotations(); std::vector<MethodAnnotation>& getMethodAnnotations();
std::string toString(); std::string toString();
@ -141,14 +141,14 @@ public:
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
// have a submodule `a` and that submodule should have a method or attribute // have a submodule `a` and that submodule should have a method or attribute
// `b`. // `b`.
void exportPath(c10::ClassType &rootClassType, void exportPath(
std::vector<std::string> exportedPath); c10::ClassType& rootClassType, std::vector<std::string> exportedPath);
// Mark everything as not-exported. // Mark everything as not-exported.
// //
// This is kind of useless by itself, but together with `exportPath` allows // This is kind of useless by itself, but together with `exportPath` allows
// exporting a subset of known names out of a larger collection of unknown // exporting a subset of known names out of a larger collection of unknown
// names. // names.
void exportNone(c10::ClassType &rootClassType); void exportNone(c10::ClassType& rootClassType);
// Annotate shapes and dtypes of the arguments of a method at path `path` from // Annotate shapes and dtypes of the arguments of a method at path `path` from
// `rootClassType`. // `rootClassType`.
@ -159,23 +159,23 @@ public:
// a "has value semantics" boolean. // a "has value semantics" boolean.
// These will be put into an `ArgAnnotation` struct -- see there for // These will be put into an `ArgAnnotation` struct -- see there for
// precise definitions of the promised semantics of each entry. // precise definitions of the promised semantics of each entry.
void annotateArgs(c10::ClassType &rootClassType, void annotateArgs(
std::vector<std::string> path, c10::ClassType& rootClassType, std::vector<std::string> path,
std::vector<ArgAnnotation> argAnnotations); std::vector<ArgAnnotation> argAnnotations);
// The annotations collected so far. // The annotations collected so far.
const ClassAnnotationMap &getAnnotationMap(); const ClassAnnotationMap& getAnnotationMap();
// Get the ClassAnnotation corresponding to `classType`. // Get the ClassAnnotation corresponding to `classType`.
ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType); ClassAnnotation& getOrCreateClassAnnotation(c10::ClassType* classType);
// Helper to find the MethodAnnotation corresponding to a // Helper to find the MethodAnnotation corresponding to a
// torch::jit::Function, or null if not found. // torch::jit::Function, or null if not found.
// //
// Users could in principle scan all annotations to find this, but it's more // Users could in principle scan all annotations to find this, but it's more
// efficient to maintain the reverse mapping directly. // efficient to maintain the reverse mapping directly.
MethodAnnotation * MethodAnnotation*
getMethodAnnotationForFunction(torch::jit::Function *function); getMethodAnnotationForFunction(torch::jit::Function* function);
std::string toString(); std::string toString();
@ -183,11 +183,11 @@ private:
// Traverse `path` starting from `rootClassType` to find the ClassType // Traverse `path` starting from `rootClassType` to find the ClassType
// of a presumed nested submodule. Throw an error if there is no such // of a presumed nested submodule. Throw an error if there is no such
// submodule. // submodule.
c10::ClassType *getClassAtPath(c10::ClassType *rootClassType, c10::ClassType*
std::vector<std::string> path); getClassAtPath(c10::ClassType* rootClassType, std::vector<std::string> path);
ClassAnnotationMap classAnnotations; ClassAnnotationMap classAnnotations;
// Reverse mapping used to service getMethodAnnotationForFunction. // Reverse mapping used to service getMethodAnnotationForFunction.
std::unordered_map<torch::jit::Function *, MethodAnnotation *> std::unordered_map<torch::jit::Function*, MethodAnnotation*>
functionToMethodMap; functionToMethodMap;
}; };

View File

@ -18,7 +18,7 @@ using namespace torch_mlir;
static c10::ScalarType convertToC10ScalarType(py::object obj) { static c10::ScalarType convertToC10ScalarType(py::object obj) {
if (THPDtype_Check(obj.ptr())) { if (THPDtype_Check(obj.ptr())) {
// Need reinterpret_cast, since no C++-level inheritance is involved. // Need reinterpret_cast, since no C++-level inheritance is involved.
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr()); THPDtype* dtype = reinterpret_cast<THPDtype*>(obj.ptr());
return dtype->scalar_type; return dtype->scalar_type;
} }
std::stringstream ss; std::stringstream ss;
@ -48,16 +48,17 @@ static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
return argAnnotations; return argAnnotations;
} }
void torch_mlir::initClassAnnotatorBindings(py::module &m) { void torch_mlir::initClassAnnotatorBindings(py::module& m) {
py::class_<ClassAnnotator>(m, "ClassAnnotator") py::class_<ClassAnnotator>(m, "ClassAnnotator")
.def(py::init<>()) .def(py::init<>())
.def("exportPath", &ClassAnnotator::exportPath) .def("exportPath", &ClassAnnotator::exportPath)
.def("exportNone", &ClassAnnotator::exportNone) .def("exportNone", &ClassAnnotator::exportNone)
.def("annotateArgs", .def(
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType, "annotateArgs",
std::vector<std::string> path, py::list argAnnotations) { [&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType,
cls_annotator.annotateArgs(rootClassType, path, std::vector<std::string> path, py::list argAnnotations) {
getArgAnnotations(argAnnotations)); cls_annotator.annotateArgs(
}) rootClassType, path, getArgAnnotations(argAnnotations));
})
.def("__repr__", &ClassAnnotator::toString); .def("__repr__", &ClassAnnotator::toString);
} }

View File

@ -18,7 +18,7 @@
namespace py = pybind11; namespace py = pybind11;
namespace torch_mlir { namespace torch_mlir {
void initClassAnnotatorBindings(py::module &m); void initClassAnnotatorBindings(py::module& m);
} // namespace torch_mlir } // namespace torch_mlir
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H #endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H

View File

@ -21,9 +21,9 @@
using namespace torch_mlir; using namespace torch_mlir;
MlirOperation torch_mlir::importJitFunctionAsFuncOp( MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function, MlirContext context, torch::jit::Function* function,
std::function<MlirAttribute(int)> getArgAttribute, std::function<MlirAttribute(int)> getArgAttribute,
const ImportOptions &importOptions) { const ImportOptions& importOptions) {
// Useful for debugging: // Useful for debugging:
// graph->dump(); // graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
@ -63,10 +63,11 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
} }
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues, auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd(appendToBlock, "func.return", loc, createMlirOperationAtEnd(
adjustStaticInformationForValues( appendToBlock, "func.return", loc,
appendToBlock, loc, yieldedValues, resultTypes, adjustStaticInformationForValues(
/*userAllowsRefinement=*/false)); appendToBlock, loc, yieldedValues, resultTypes,
/*userAllowsRefinement=*/false));
}; };
MlirBlock block = importBlock( MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(), context, torch::jit::toGraphFunction(*function).graph()->block(),

View File

@ -40,10 +40,10 @@ namespace torch_mlir {
/// null MlirAttribute is returned, no attribute will be attached to that /// null MlirAttribute is returned, no attribute will be attached to that
/// argument. /// argument.
MlirOperation importJitFunctionAsFuncOp( MlirOperation importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function, MlirContext context, torch::jit::Function* function,
std::function<MlirAttribute(int)> getArgAttribute = std::function<MlirAttribute(int)> getArgAttribute =
[](int) -> MlirAttribute { return {nullptr}; }, [](int) -> MlirAttribute { return {nullptr}; },
const ImportOptions &importOptions = {}); const ImportOptions& importOptions = {});
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -50,9 +50,9 @@ static py::list getRegisteredOps() {
// since the JIT has its own dispatch mechanism that it uses to implement // since the JIT has its own dispatch mechanism that it uses to implement
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such // "prim" ops and a handful of "aten" ops that are effectively prim ops, such
// as `aten::__is__`. // as `aten::__is__`.
for (const std::shared_ptr<torch::jit::Operator> &op : for (const std::shared_ptr<torch::jit::Operator>& op :
torch::jit::getAllOperators()) { torch::jit::getAllOperators()) {
const c10::FunctionSchema &schema = op->schema(); const c10::FunctionSchema& schema = op->schema();
py::dict record; py::dict record;
{ {
@ -69,7 +69,7 @@ static py::list getRegisteredOps() {
py::list arguments; py::list arguments;
py::list returns; py::list returns;
auto addArgument = [](py::list &container, const c10::Argument &arg) { auto addArgument = [](py::list& container, const c10::Argument& arg) {
py::dict argRecord; py::dict argRecord;
argRecord["name"] = arg.name(); argRecord["name"] = arg.name();
argRecord["type"] = arg.type()->str(); argRecord["type"] = arg.type()->str();
@ -87,10 +87,10 @@ static py::list getRegisteredOps() {
py::dict aliasInfo; py::dict aliasInfo;
py::list before; py::list before;
py::list after; py::list after;
for (auto &symbol : arg.alias_info()->beforeSets()) { for (auto& symbol : arg.alias_info()->beforeSets()) {
before.append(std::string(symbol.toQualString())); before.append(std::string(symbol.toQualString()));
} }
for (auto &symbol : arg.alias_info()->afterSets()) { for (auto& symbol : arg.alias_info()->afterSets()) {
after.append(std::string(symbol.toQualString())); after.append(std::string(symbol.toQualString()));
} }
aliasInfo["is_write"] = arg.alias_info()->isWrite(); aliasInfo["is_write"] = arg.alias_info()->isWrite();
@ -101,10 +101,10 @@ static py::list getRegisteredOps() {
container.append(std::move(argRecord)); container.append(std::move(argRecord));
}; };
for (auto &argument : schema.arguments()) { for (auto& argument : schema.arguments()) {
addArgument(arguments, argument); addArgument(arguments, argument);
} }
for (auto &returnArg : schema.returns()) { for (auto& returnArg : schema.returns()) {
addArgument(returns, returnArg); addArgument(returns, returnArg);
} }
record["arguments"] = std::move(arguments); record["arguments"] = std::move(arguments);
@ -115,6 +115,6 @@ static py::list getRegisteredOps() {
return results; return results;
} }
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) { void torch_mlir::initGetRegisteredOpsBindings(py::module& m) {
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring); m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
} }

View File

@ -19,7 +19,7 @@
namespace torch_mlir { namespace torch_mlir {
void initGetRegisteredOpsBindings(py::module &m); void initGetRegisteredOpsBindings(py::module& m);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -14,11 +14,13 @@ namespace py = pybind11;
using namespace torch_mlir; using namespace torch_mlir;
void torch_mlir::initImportOptionsBindings(py::module &m) { void torch_mlir::initImportOptionsBindings(py::module& m) {
py::class_<ImportOptions>(m, "ImportOptions") py::class_<ImportOptions>(m, "ImportOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("assumeTensorsHaveValueSemantics", .def_readwrite(
&ImportOptions::assumeTensorsHaveValueSemantics) "assumeTensorsHaveValueSemantics",
.def_readwrite("ignoreExistingTensorShapesAndDtypes", &ImportOptions::assumeTensorsHaveValueSemantics)
&ImportOptions::ignoreExistingTensorShapesAndDtypes); .def_readwrite(
"ignoreExistingTensorShapesAndDtypes",
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
} }

View File

@ -13,7 +13,7 @@
#include <torch/csrc/utils/pybind.h> #include <torch/csrc/utils/pybind.h>
namespace torch_mlir { namespace torch_mlir {
void initImportOptionsBindings(pybind11::module &m); void initImportOptionsBindings(pybind11::module& m);
} // namespace torch_mlir } // namespace torch_mlir
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H #endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H

View File

@ -49,10 +49,10 @@ using namespace torch_mlir;
// throw an error on). // throw an error on).
namespace { namespace {
struct IValueHasher { struct IValueHasher {
size_t operator()(const c10::IValue &ivalue) const { size_t operator()(const c10::IValue& ivalue) const {
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
return std::hash<const void *>()( return std::hash<const void*>()(
static_cast<const void *>(ivalue.internalToPointer())); static_cast<const void*>(ivalue.internalToPointer()));
} }
return c10::IValue::hash(ivalue); return c10::IValue::hash(ivalue);
@ -65,7 +65,7 @@ struct IValueHasher {
// such as when tracing). Can we do better? // such as when tracing). Can we do better?
namespace { namespace {
struct IValueEq { struct IValueEq {
bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const { bool operator()(const c10::IValue& lhs, const c10::IValue& rhs) const {
return lhs.isSameIdentity(rhs); return lhs.isSameIdentity(rhs);
} }
}; };
@ -99,8 +99,9 @@ namespace {
/// (PyTorch allows this!). /// (PyTorch allows this!).
class IValueImporter { class IValueImporter {
public: public:
IValueImporter(MlirBlock importBlock, MlirContext context, IValueImporter(
ClassAnnotator &annotator, const ImportOptions &importOptions) MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator,
const ImportOptions& importOptions)
: importBlock(importBlock), context(context), annotator(annotator), : importBlock(importBlock), context(context), annotator(annotator),
importOptions(importOptions) {} importOptions(importOptions) {}
@ -110,15 +111,16 @@ private:
MlirValue rawImportIValue(c10::IValue ivalue); MlirValue rawImportIValue(c10::IValue ivalue);
MlirValue importTensor(c10::IValue ivalue); MlirValue importTensor(c10::IValue ivalue);
MlirValue importModule(torch::jit::Module jitModule); MlirValue importModule(torch::jit::Module jitModule);
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody, void importMethod(
const MethodAnnotation &methodAnnotation); torch::jit::Function* function, MlirBlock classTypeBody,
void importClassType(c10::ClassType *classType); const MethodAnnotation& methodAnnotation);
void importCompilationUnit(torch::jit::CompilationUnit *cu); void importClassType(c10::ClassType* classType);
void importCompilationUnit(torch::jit::CompilationUnit* cu);
MlirBlock importBlock; MlirBlock importBlock;
MlirContext context; MlirContext context;
ClassAnnotator &annotator; ClassAnnotator& annotator;
const ImportOptions &importOptions; const ImportOptions& importOptions;
// Map tracking already-imported values. // Map tracking already-imported values.
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap; std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
@ -129,16 +131,16 @@ private:
// e.g. methods (the function names are meaningful and match with Python's // e.g. methods (the function names are meaningful and match with Python's
// module hierarchy, with the exception of `__main__` being replaced with // module hierarchy, with the exception of `__main__` being replaced with
// `__torch__`). // `__torch__`).
torch::jit::CompilationUnit *compilationUnit = nullptr; torch::jit::CompilationUnit* compilationUnit = nullptr;
// Used to detect potentially aliasing tensors. // Used to detect potentially aliasing tensors.
std::unordered_set<c10::StorageImpl *> seenStorageImpls; std::unordered_set<c10::StorageImpl*> seenStorageImpls;
// The set of ClassType's that have already been imported. // The set of ClassType's that have already been imported.
// //
// ClassType's are referenced via their `classType->name()->qualifiedName()` // ClassType's are referenced via their `classType->name()->qualifiedName()`
// string (as an MLIR symbol name) so we don't need to keep a map associating // string (as an MLIR symbol name) so we don't need to keep a map associating
// them with the MlirOperation that they import into. // them with the MlirOperation that they import into.
std::unordered_set<c10::ClassType *> classTypes; std::unordered_set<c10::ClassType*> classTypes;
// The stack of attribute names we have traversed to reach the current IValue. // The stack of attribute names we have traversed to reach the current IValue.
// Used for diagnostics. // Used for diagnostics.
std::vector<std::string> attributeNameStack; std::vector<std::string> attributeNameStack;
@ -190,7 +192,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate()); mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); mlirRegionAppendOwnedBlock(
nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
InserterGuard inserterGuard(importBlock, nnModule); InserterGuard inserterGuard(importBlock, nnModule);
@ -198,13 +201,14 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
rootModuleName = moduleTypeName; rootModuleName = moduleTypeName;
} }
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots(); const std::vector<c10::IValue>& slots = currentModule._ivalue()->slots();
const std::vector<c10::ClassAttribute> &classAttributes = const std::vector<c10::ClassAttribute>& classAttributes =
currentModule.type()->getAttributes(); currentModule.type()->getAttributes();
assert(slots.size() == classAttributes.size() && assert(
"mismatch between object and type!"); slots.size() == classAttributes.size() &&
"mismatch between object and type!");
for (int i = 0, e = slots.size(); i < e; i++) { for (int i = 0, e = slots.size(); i < e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i]; const c10::ClassAttribute& classAttribute = classAttributes[i];
attributeNameStack.push_back(classAttribute.getName()); attributeNameStack.push_back(classAttribute.getName());
MlirValue slotValue = importIValue(slots[i]); MlirValue slotValue = importIValue(slots[i]);
// TODO: Is it necessary to track whether an attribute is a "parameter"? // TODO: Is it necessary to track whether an attribute is a "parameter"?
@ -231,7 +235,7 @@ MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
} }
// Reject potentially aliased tensors. // Reject potentially aliased tensors.
if (ivalue.isTensor()) { if (ivalue.isTensor()) {
c10::StorageImpl *storageImpl = c10::StorageImpl* storageImpl =
ivalue.toTensor().storage().unsafeGetStorageImpl(); ivalue.toTensor().storage().unsafeGetStorageImpl();
if (!seenStorageImpls.insert(storageImpl).second) { if (!seenStorageImpls.insert(storageImpl).second) {
std::stringstream msg; std::stringstream msg;
@ -257,8 +261,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirType type = torchMlirTorchBoolTypeGet(context); MlirType type = torchMlirTorchBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.bool", loc, type, importBlock, "torch.constant.bool", loc, type,
toMlirNamedAttribute("value", toMlirNamedAttribute(
mlirBoolAttrGet(context, ivalue.toBool()))); "value", mlirBoolAttrGet(context, ivalue.toBool())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isDouble()) { if (ivalue.isDouble()) {
@ -266,23 +270,23 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.float", loc, type, importBlock, "torch.constant.float", loc, type,
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), "value", mlirFloatAttrDoubleGet(
ivalue.toDouble()))); context, mlirF64TypeGet(context), ivalue.toDouble())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isInt()) { if (ivalue.isInt()) {
MlirType type = torchMlirTorchIntTypeGet(context); MlirType type = torchMlirTorchIntTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.int", loc, type, importBlock, "torch.constant.int", loc, type,
toMlirNamedAttribute("value", toMlirNamedAttribute(
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), "value", mlirIntegerAttrGet(
ivalue.toInt()))); mlirIntegerTypeGet(context, 64), ivalue.toInt())));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isList()) { if (ivalue.isList()) {
c10::List<c10::IValue> list = ivalue.toList(); c10::List<c10::IValue> list = ivalue.toList();
std::vector<MlirValue> elems; std::vector<MlirValue> elems;
for (const c10::IValue &elem : list) { for (const c10::IValue& elem : list) {
elems.push_back(importIValue(elem)); elems.push_back(importIValue(elem));
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
@ -312,7 +316,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
auto list = ivalue.toTuple()->elements(); auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> operands; std::vector<MlirValue> operands;
std::vector<MlirType> types; std::vector<MlirType> types;
for (const c10::IValue &elem : list) { for (const c10::IValue& elem : list) {
MlirValue operand = importIValue(elem); MlirValue operand = importIValue(elem);
operands.push_back(operand); operands.push_back(operand);
types.push_back(mlirValueGetType(operand)); types.push_back(mlirValueGetType(operand));
@ -335,14 +339,14 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
torchMlirTorchStringTypeGet(context), torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirStringAttrGet(context, mlirStringAttrGet(
toMlirStringRef(ivalue.toString()->string())))); context, toMlirStringRef(ivalue.toString()->string()))));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isNone()) { if (ivalue.isNone()) {
MlirOperation operation = MlirOperation operation = createMlirOperationAtEnd(
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc, importBlock, "torch.constant.none", loc,
torchMlirTorchNoneTypeGet(context)); torchMlirTorchNoneTypeGet(context));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isCustomClass()) { if (ivalue.isCustomClass()) {
@ -436,12 +440,12 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
return tensorValue; return tensorValue;
} }
void IValueImporter::importMethod(torch::jit::Function *function, void IValueImporter::importMethod(
MlirBlock classTypeBody, torch::jit::Function* function, MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation) { const MethodAnnotation& methodAnnotation) {
// The function's name becomes the MLIR symbol table name of the imported func // The function's name becomes the MLIR symbol table name of the imported func
// when we import the compilation unit. // when we import the compilation unit.
const std::string &symName = function->qualname().qualifiedName(); const std::string& symName = function->qualname().qualifiedName();
MlirAttribute functionSymbolRef = MlirAttribute functionSymbolRef =
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)); mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
@ -457,7 +461,7 @@ void IValueImporter::importMethod(torch::jit::Function *function,
toMlirNamedAttribute("function", functionSymbolRef), isPrivate); toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
} }
void IValueImporter::importClassType(c10::ClassType *classType) { void IValueImporter::importClassType(c10::ClassType* classType) {
if (!classTypes.insert(classType).second) { if (!classTypes.insert(classType).second) {
return; return;
} }
@ -475,13 +479,13 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr)); mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region); MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
ClassAnnotation &classAnnotation = ClassAnnotation& classAnnotation =
annotator.getOrCreateClassAnnotation(classType); annotator.getOrCreateClassAnnotation(classType);
const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations(); const auto& attributeAnnotations = classAnnotation.getAttributeAnnotations();
const auto &classAttributes = classType->getAttributes(); const auto& classAttributes = classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) { for (int i = 0, e = classAttributes.size(); i != e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i]; const c10::ClassAttribute& classAttribute = classAttributes[i];
c10::optional<MlirNamedAttribute> isPrivate; c10::optional<MlirNamedAttribute> isPrivate;
if (!attributeAnnotations[i].isExported) { if (!attributeAnnotations[i].isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context)); isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
@ -491,13 +495,14 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
toMlirNamedAttribute( toMlirNamedAttribute(
"name", mlirStringAttrGet( "name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))), context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType( toMlirNamedAttribute(
loc, classAttribute.getType(), importOptions))), "type", mlirTypeAttrGet(getMlirTypeFromTorchType(
loc, classAttribute.getType(), importOptions))),
isPrivate); isPrivate);
} }
const auto &methodAnnotations = classAnnotation.getMethodAnnotations(); const auto& methodAnnotations = classAnnotation.getMethodAnnotations();
const auto &methods = classType->methods(); const auto& methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) { for (int i = 0, e = methods.size(); i != e; i++) {
importMethod(methods[i], classTypeBody, methodAnnotations[i]); importMethod(methods[i], classTypeBody, methodAnnotations[i]);
} }
@ -505,7 +510,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc); createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
} }
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
if (compilationUnit == nullptr) { if (compilationUnit == nullptr) {
compilationUnit = cu; compilationUnit = cu;
} else { } else {
@ -524,14 +529,14 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
return; return;
} }
for (torch::jit::Function *function : cu->get_functions()) { for (torch::jit::Function* function : cu->get_functions()) {
// Useful for debugging errors in free functions that end up being // Useful for debugging errors in free functions that end up being
// unused. These can be missing when round-tripping through the on-disk // unused. These can be missing when round-tripping through the on-disk
// format, even though they still cause import issues when importing // format, even though they still cause import issues when importing
// through the larger Python session where they originate. // through the larger Python session where they originate.
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n"; // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
// std::cerr << *torch::jit::toGraphFunction(function).graph(); // std::cerr << *torch::jit::toGraphFunction(function).graph();
MethodAnnotation *annotation = MethodAnnotation* annotation =
annotator.getMethodAnnotationForFunction(function); annotator.getMethodAnnotationForFunction(function);
MlirOperation func = importJitFunctionAsFuncOp( MlirOperation func = importJitFunctionAsFuncOp(
context, function, context, function,
@ -539,9 +544,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
if (!annotation || !annotation->argAnnotations.has_value()) { if (!annotation || !annotation->argAnnotations.has_value()) {
return {nullptr}; return {nullptr};
} }
c10::optional<std::vector<int64_t>> &maybeShape = c10::optional<std::vector<int64_t>>& maybeShape =
annotation->argAnnotations.value()[argIndex].shape; annotation->argAnnotations.value()[argIndex].shape;
c10::optional<c10::ScalarType> &maybeDtype = c10::optional<c10::ScalarType>& maybeDtype =
annotation->argAnnotations.value()[argIndex].dtype; annotation->argAnnotations.value()[argIndex].dtype;
bool hasValueSemantics = bool hasValueSemantics =
annotation->argAnnotations.value()[argIndex].hasValueSemantics; annotation->argAnnotations.value()[argIndex].hasValueSemantics;
@ -561,10 +566,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
// the C API constructor, when we want the "we know we have 0 sizes" // the C API constructor, when we want the "we know we have 0 sizes"
// case. So use a dummy data pointer. // case. So use a dummy data pointer.
int64_t dummy; int64_t dummy;
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data(); int64_t* shapeData = shape.size() == 0 ? &dummy : shape.data();
if (hasValueSemantics) { if (hasValueSemantics) {
typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(), typeBound = torchMlirTorchValueTensorTypeGet(
shapeData, dtype); context, shape.size(), shapeData, dtype);
} else { } else {
typeBound = torchMlirTorchNonValueTensorTypeGet( typeBound = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shapeData, dtype); context, shape.size(), shapeData, dtype);
@ -592,10 +597,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
} }
} }
MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block, MlirValue torch_mlir::importIValue(
MlirContext context, c10::IValue ivalue, MlirBlock block, MlirContext context,
ClassAnnotator &annotator, ClassAnnotator& annotator, const ImportOptions& importOptions) {
const ImportOptions &importOptions) {
// When debugging module importing, it can be useful to dump as so: // When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule()) // if (ivalue.isModule())
// ivalue.toModule().dump(true, false, false); // ivalue.toModule().dump(true, false, false);

View File

@ -25,9 +25,9 @@ namespace torch_mlir {
/// Main entry-point for importing torch IValue's . /// Main entry-point for importing torch IValue's .
/// Recursively imports `ivalue`, inserting operations at the end of `block`. /// Recursively imports `ivalue`, inserting operations at the end of `block`.
MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context, MlirValue importIValue(
ClassAnnotator &annotator, c10::IValue ivalue, MlirBlock block, MlirContext context,
const ImportOptions &importOptions); ClassAnnotator& annotator, const ImportOptions& importOptions);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -22,92 +22,92 @@
namespace torch_mlir { namespace torch_mlir {
inline MlirStringRef toMlirStringRef(const std::string &s) { inline MlirStringRef toMlirStringRef(const std::string& s) {
return mlirStringRefCreate(s.data(), s.size()); return mlirStringRefCreate(s.data(), s.size());
} }
inline MlirStringRef toMlirStringRef(const char *s) { inline MlirStringRef toMlirStringRef(const char* s) {
return mlirStringRefCreate(s, std::strlen(s)); return mlirStringRefCreate(s, std::strlen(s));
} }
inline MlirNamedAttribute toMlirNamedAttribute(const char *s, inline MlirNamedAttribute
MlirAttribute attr) { toMlirNamedAttribute(const char* s, MlirAttribute attr) {
MlirContext context = mlirAttributeGetContext(attr); MlirContext context = mlirAttributeGetContext(attr);
MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
return mlirNamedAttributeGet(ident, attr); return mlirNamedAttributeGet(ident, attr);
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void addToMlirOperationState(
MlirNamedAttribute namedAttr) { MlirOperationState& state, MlirNamedAttribute namedAttr) {
mlirOperationStateAddAttributes(&state, 1, &namedAttr); mlirOperationStateAddAttributes(&state, 1, &namedAttr);
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void
MlirRegion region) { addToMlirOperationState(MlirOperationState& state, MlirRegion region) {
mlirOperationStateAddOwnedRegions(&state, 1, &region); mlirOperationStateAddOwnedRegions(&state, 1, &region);
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void
MlirValue value) { addToMlirOperationState(MlirOperationState& state, MlirValue value) {
mlirOperationStateAddOperands(&state, 1, &value); mlirOperationStateAddOperands(&state, 1, &value);
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void addToMlirOperationState(
const std::vector<MlirValue> &values) { MlirOperationState& state, const std::vector<MlirValue>& values) {
mlirOperationStateAddOperands(&state, values.size(), values.data()); mlirOperationStateAddOperands(&state, values.size(), values.data());
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void addToMlirOperationState(
c10::ArrayRef<MlirValue> values) { MlirOperationState& state, c10::ArrayRef<MlirValue> values) {
mlirOperationStateAddOperands(&state, values.size(), values.data()); mlirOperationStateAddOperands(&state, values.size(), values.data());
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void
MlirType resultType) { addToMlirOperationState(MlirOperationState& state, MlirType resultType) {
mlirOperationStateAddResults(&state, 1, &resultType); mlirOperationStateAddResults(&state, 1, &resultType);
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void addToMlirOperationState(
const std::vector<MlirType> &resultTypes) { MlirOperationState& state, const std::vector<MlirType>& resultTypes) {
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
} }
inline void addToMlirOperationState(MlirOperationState &state, inline void addToMlirOperationState(
c10::ArrayRef<MlirType> resultTypes) { MlirOperationState& state, c10::ArrayRef<MlirType> resultTypes) {
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
} }
template <typename T> template <typename T>
void addToMlirOperationState(MlirOperationState &state, c10::optional<T> o) { void addToMlirOperationState(MlirOperationState& state, c10::optional<T> o) {
if (o.has_value()) { if (o.has_value()) {
addToMlirOperationState(state, o.value()); addToMlirOperationState(state, o.value());
} }
} }
inline void addToMlirOperationState(MlirOperationState &state) {} inline void addToMlirOperationState(MlirOperationState& state) {}
template <typename T, typename U, typename... Ts> template <typename T, typename U, typename... Ts>
void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, void addToMlirOperationState(
Ts &&...ts) { MlirOperationState& state, T&& t, U&& u, Ts&&... ts) {
addToMlirOperationState(state, std::forward<T>(t)); addToMlirOperationState(state, std::forward<T>(t));
addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...); addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
} }
/// Creates an MLIR operation named `name` at location `loc`, populating its
/// state from the remaining arguments (operands, result types, named
/// attributes, or optionals thereof) via the addToMlirOperationState
/// overload set.
template <typename... Ts>
MlirOperation createMlirOperation(std::string name, MlirLocation loc,
                                  Ts &&...ts) {
  MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
  addToMlirOperationState(state, std::forward<Ts>(ts)...);
  return mlirOperationCreate(&state);
}
/// Same as createMlirOperation, but additionally inserts the new operation
/// into `block`, immediately before the block's terminator.
template <typename... Ts>
MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
                                       MlirLocation loc, Ts &&...ts) {
  MlirOperation operation =
      createMlirOperation(name, loc, std::forward<Ts>(ts)...);
  mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
                                      operation);
  return operation;
}

View File

@ -22,7 +22,7 @@
namespace py = pybind11; namespace py = pybind11;
using namespace torch_mlir; using namespace torch_mlir;
// Imports the MLIR Python `ir` module and returns the attribute named
// `className` (typically a class object such as Context or Module).
static py::object getMlirIrClass(const char *className) {
  return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
}
@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) {
return contextObj; return contextObj;
} }
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) { static MlirContext castPythonObjectToMlirContext(py::object& contextObj) {
assert(!contextObj.is_none() && "context cannot be None"); assert(!contextObj.is_none() && "context cannot be None");
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR); auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr()); MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
std::stringstream ss; std::stringstream ss;
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic)) ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
<< ": "; << ": ";
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) { auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) {
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData); auto* ssp = static_cast<std::stringstream*>(stringCallbackUserData);
ssp->write(s.data, s.length); ssp->write(s.data, s.length);
}; };
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss)); mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void*>(&ss));
// Use pybind11's print: // Use pybind11's print:
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
py::print(ss.str(), py::print(
py::arg("file") = py::module_::import("sys").attr("stderr")); ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr"));
} }
// Register a diagnostic handler that will redirect output to `sys.stderr` // Register a diagnostic handler that will redirect output to `sys.stderr`
@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks. // that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
auto diagnosticHandler = [](MlirDiagnostic diagnostic, auto diagnosticHandler = [](MlirDiagnostic diagnostic,
void *) -> MlirLogicalResult { void*) -> MlirLogicalResult {
printDiagnostic(diagnostic); printDiagnostic(diagnostic);
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) { for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i)); printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
return mlirLogicalResultSuccess(); return mlirLogicalResultSuccess();
}; };
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
context, diagnosticHandler, nullptr, [](void *) { return; }); context, diagnosticHandler, nullptr, [](void*) { return; });
// Ignore the ID. We intend to keep this handler for the entire lifetime // Ignore the ID. We intend to keep this handler for the entire lifetime
// of this context. // of this context.
(void)id; (void)id;
@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
terminator = mlirBlockGetFirstOperation(getBodyBlock()); terminator = mlirBlockGetFirstOperation(getBodyBlock());
} }
// Imports a single TorchScript function into the module as a func op,
// inserting it before the module's terminator. `maybeImportOptions`, if not
// None, must be castable to ImportOptions. Returns the same function so the
// binding can be used as a nested decorator on the Python side.
torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
                              py::object maybeImportOptions) {
  ImportOptions importOptions;
  if (!maybeImportOptions.is_none()) {
    importOptions = py::cast<ImportOptions>(maybeImportOptions);
  }
  MlirBlock block = getBodyBlock();
  MlirOperation terminator = this->terminator;
  // The attribute callback returns a null attribute: no per-argument
  // attributes are attached when importing a free function.
  MlirOperation func = importJitFunctionAsFuncOp(
      context, function.function_,
      [](int) -> MlirAttribute { return {nullptr}; }, importOptions);
  mlirBlockInsertOwnedOperationBefore(block, terminator, func);
  return function;
}
void ModuleBuilder::importModule(torch::jit::Module jitModule, void ModuleBuilder::importModule(
py::object maybeClassAnnotator, torch::jit::Module jitModule, py::object maybeClassAnnotator,
py::object maybeImportOptions) { py::object maybeImportOptions) {
ClassAnnotator dummyAnnotator; ClassAnnotator dummyAnnotator;
ClassAnnotator *classAnnotator = &dummyAnnotator; ClassAnnotator* classAnnotator = &dummyAnnotator;
if (!maybeClassAnnotator.is_none()) { if (!maybeClassAnnotator.is_none()) {
classAnnotator = py::cast<ClassAnnotator *>(maybeClassAnnotator); classAnnotator = py::cast<ClassAnnotator*>(maybeClassAnnotator);
} }
ImportOptions importOptions; ImportOptions importOptions;
if (!maybeImportOptions.is_none()) { if (!maybeImportOptions.is_none()) {
@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
// precise `torch.class_type` names. // precise `torch.class_type` names.
// //
// This name is not semantically load-bearing!!! // This name is not semantically load-bearing!!!
auto &name = *jitModule.type()->name(); auto& name = *jitModule.type()->name();
auto debugModuleNameAttr = mlirStringAttrGet( auto debugModuleNameAttr = mlirStringAttrGet(
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1])); context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
mlirOperationSetAttributeByName(mlirModuleGetOperation(module), mlirOperationSetAttributeByName(
toMlirStringRef("torch.debug_module_name"), mlirModuleGetOperation(module),
debugModuleNameAttr); toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr);
importIValue(jitModule._ivalue(), mlirModuleGetBody(module), importIValue(
mlirModuleGetContext(module), *classAnnotator, importOptions); jitModule._ivalue(), mlirModuleGetBody(module),
mlirModuleGetContext(module), *classAnnotator, importOptions);
} }
MlirBlock ModuleBuilder::getBodyBlock() { MlirBlock ModuleBuilder::getBodyBlock() {
@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() {
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0)); return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
} }
// Creates the pybind11 bindings for ModuleBuilder: constructor (optional
// `context`), read-only `context`/`module` properties, and the
// `import_function`/`import_module` entry points with their optional
// keyword arguments defaulting to None.
void ModuleBuilder::bind(py::module &m) {
  py::class_<ModuleBuilder>(m, "ModuleBuilder")
      .def(py::init<py::object>(), py::arg("context") = py::none())
      .def_property_readonly("context", &ModuleBuilder::getContextObj)
      .def_property_readonly("module", &ModuleBuilder::getModuleObj)
      .def("import_function", &ModuleBuilder::importFunction,
           py::arg("function"), py::arg("importOptions") = py::none())
      .def("import_module", &ModuleBuilder::importModule, py::arg("module"),
           py::arg("classAnnotator") = py::none(),
           py::arg("importOptions") = py::none());
}

View File

@ -29,7 +29,7 @@ public:
ModuleBuilder(pybind11::object contextObj); ModuleBuilder(pybind11::object contextObj);
/// Creates Python bindings for the class. /// Creates Python bindings for the class.
static void bind(pybind11::module &m); static void bind(pybind11::module& m);
pybind11::object getContextObj() { return contextObj; } pybind11::object getContextObj() { return contextObj; }
pybind11::object getModuleObj() { return moduleObj; } pybind11::object getModuleObj() { return moduleObj; }
@ -38,16 +38,15 @@ public:
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr. // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
// Just a bit of naming cruft. // Just a bit of naming cruft.
// Returns the same function, making it suitable as a nested decorator. // Returns the same function, making it suitable as a nested decorator.
torch::jit::StrongFunctionPtr torch::jit::StrongFunctionPtr importFunction(
importFunction(torch::jit::StrongFunctionPtr function, torch::jit::StrongFunctionPtr function, py::object maybeImportOptions);
py::object maybeImportOptions);
// Imports a torch::jit::Module into the current module, using the // Imports a torch::jit::Module into the current module, using the
// annotations, if not none, provided in `maybeClassAnnotator` which should be // annotations, if not none, provided in `maybeClassAnnotator` which should be
// a ClassAnnotator. // a ClassAnnotator.
void importModule(torch::jit::Module jitModule, void importModule(
py::object maybeClassAnnotator, torch::jit::Module jitModule, py::object maybeClassAnnotator,
py::object maybeImportOptions); py::object maybeImportOptions);
private: private:
MlirBlock getBodyBlock(); MlirBlock getBodyBlock();

View File

@ -33,41 +33,42 @@ class NodeImporter {
public: public:
NodeImporter(MlirContext context) : context(context) {} NodeImporter(MlirContext context) : context(context) {}
void importNode(Node *node, MlirBlock appendToBlock, void importNode(
const ImportOptions &importOptions = {}); Node* node, MlirBlock appendToBlock,
const ImportOptions& importOptions = {});
MlirBlock importBlock( MlirBlock importBlock(
Block *jitBlock, CreateTerminatorFn createTerminator, Block* jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions &importOptions = {}); const ImportOptions& importOptions = {});
private: private:
MlirBlock MlirBlock createBlockFor(
createBlockFor(Block *jitBlock, Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, const ImportOptions& importOptions = {});
const ImportOptions &importOptions = {}); void mapValue(Value* jitValue, MlirValue value);
void mapValue(Value *jitValue, MlirValue value); void mapResults(Node* node, MlirOperation operation);
void mapResults(Node *node, MlirOperation operation); MlirValue lookupMappedValue(Value* jitValue);
MlirValue lookupMappedValue(Value *jitValue); std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value*> values);
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);
MlirContext context; MlirContext context;
std::unordered_map<Value *, MlirValue> valueMap; std::unordered_map<Value*, MlirValue> valueMap;
}; };
} // namespace } // namespace
using InputsTransformFn = using InputsTransformFn =
std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>; std::function<std::vector<MlirValue>(std::vector<MlirValue>&)>;
// The inputs of `DictConstruct` in TorchScript IR are in the order // The inputs of `DictConstruct` in TorchScript IR are in the order
// like k0, v0, k1, v1. Rearrange them to put the key operands together and // like k0, v0, k1, v1. Rearrange them to put the key operands together and
// then the value operands like k0, k1,v0, v1. This is the expected format by // then the value operands like k0, k1,v0, v1. This is the expected format by
// the corresponding MLIR op. // the corresponding MLIR op.
static std::vector<MlirValue> static std::vector<MlirValue>
rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) { rearrangeDictConstructInputs(std::vector<MlirValue>& inputs) {
if (inputs.empty()) if (inputs.empty())
return inputs; return inputs;
assert(inputs.size() % 2 == 0 && assert(
"DictConstruct must have even number of operands"); inputs.size() % 2 == 0 &&
"DictConstruct must have even number of operands");
std::vector<MlirValue> rearranged; std::vector<MlirValue> rearranged;
std::vector<MlirValue> values; std::vector<MlirValue> values;
@ -79,12 +80,12 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
return rearranged; return rearranged;
} }
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, void NodeImporter::importNode(
const ImportOptions &importOptions) { Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) {
MlirLocation loc = getMlirLocationFromNode(context, node); MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind(); auto kind = node->kind();
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, auto createAndMapTrivialNode = [&](Node* node, const std::string& opName,
InputsTransformFn t) { InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs()); std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
@ -95,7 +96,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
}; };
auto createAndMapNodeWithAttribute = auto createAndMapNodeWithAttribute =
[&](Node *node, const std::string &opName, const std::string &attrName, [&](Node* node, const std::string& opName, const std::string& attrName,
MlirAttribute attr) { MlirAttribute attr) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc, appendToBlock, opName, loc,
@ -132,27 +133,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
// ListConstruct and DictConstruct too. // ListConstruct and DictConstruct too.
auto containedTypes = c10::fmap( auto containedTypes = c10::fmap(
node->output()->type()->cast<c10::TupleType>()->containedTypes(), node->output()->type()->cast<c10::TupleType>()->containedTypes(),
[&](const c10::TypePtr &t) { [&](const c10::TypePtr& t) {
MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
return type; return type;
}); });
createAndMapTrivialNode(node, createAndMapTrivialNode(
"torch.prim." + std::string(kind.toUnqualString()), node, "torch.prim." + std::string(kind.toUnqualString()),
[&](std::vector<MlirValue> &inputs) { [&](std::vector<MlirValue>& inputs) {
assert(containedTypes.size() == inputs.size()); assert(containedTypes.size() == inputs.size());
return adjustStaticInformationForValues( return adjustStaticInformationForValues(
appendToBlock, loc, inputs, containedTypes, appendToBlock, loc, inputs, containedTypes,
/*userAllowsRefinement=*/true); /*userAllowsRefinement=*/true);
}); });
return; return;
} }
case c10::prim::DictConstruct: { case c10::prim::DictConstruct: {
createAndMapTrivialNode(node, createAndMapTrivialNode(
"torch.prim." + std::string(kind.toUnqualString()), node, "torch.prim." + std::string(kind.toUnqualString()),
rearrangeDictConstructInputs); rearrangeDictConstructInputs);
return; return;
} }
case c10::prim::Load: case c10::prim::Load:
@ -170,32 +171,34 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
auto output = node->output(); auto output = node->output();
MlirOperation op; MlirOperation op;
if (output->type()->cast<c10::NoneType>()) { if (output->type()->cast<c10::NoneType>()) {
op = createMlirOperation("torch.constant.none", loc, op = createMlirOperation(
torchMlirTorchNoneTypeGet(context)); "torch.constant.none", loc, torchMlirTorchNoneTypeGet(context));
} else if (output->type()->cast<c10::BoolType>()) { } else if (output->type()->cast<c10::BoolType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context), "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirBoolAttrGet(context, static_cast<bool>(node->i( "value",
c10::attr::value))))); mlirBoolAttrGet(
context, static_cast<bool>(node->i(c10::attr::value)))));
} else if (output->type()->cast<c10::IntType>()) { } else if (output->type()->cast<c10::IntType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.int", loc, "torch.constant.int", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute("value", toMlirNamedAttribute(
importAttribute(loc, node, c10::attr::value))); "value", importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::FloatType>()) { } else if (output->type()->cast<c10::FloatType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.float", loc, "torch.constant.float", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute("value", toMlirNamedAttribute(
importAttribute(loc, node, c10::attr::value))); "value", importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) { } else if (output->type()->cast<c10::StringType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context), "torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s( "value",
c10::attr::value))))); mlirStringAttrGet(
context, toMlirStringRef(node->s(c10::attr::value)))));
} else if (output->type()->cast<c10::TensorType>()) { } else if (output->type()->cast<c10::TensorType>()) {
MlirAttribute attr = importAttribute(loc, node, c10::attr::value); MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
if (importOptions.assumeTensorsHaveValueSemantics) { if (importOptions.assumeTensorsHaveValueSemantics) {
@ -214,24 +217,26 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
"torch.constant.device", loc, "torch.constant.device", loc,
getMlirTypeFromTorchType(loc, output->type(), importOptions), getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s( "value",
c10::attr::value))))); mlirStringAttrGet(
context, toMlirStringRef(node->s(c10::attr::value)))));
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) { } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
torch::jit::Function *function = functionType->function(); torch::jit::Function* function = functionType->function();
const std::string &symName = function->qualname().qualifiedName(); const std::string& symName = function->qualname().qualifiedName();
op = createMlirOperation( op = createMlirOperation(
"func.constant", loc, "func.constant", loc,
getFunctionTypeFromSchema(context, function->getSchema(), getFunctionTypeFromSchema(
importOptions), context, function->getSchema(), importOptions),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
} else if (output->type()->cast<c10::ListType>() || } else if (
output->type()->cast<c10::TupleType>()) { output->type()->cast<c10::ListType>() ||
output->type()->cast<c10::TupleType>()) {
ClassAnnotator dummyAnnotator; ClassAnnotator dummyAnnotator;
MlirValue listOrTupleValue = MlirValue listOrTupleValue = importIValue(
importIValue(node->ival(c10::attr::value), appendToBlock, context, node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator,
dummyAnnotator, importOptions); importOptions);
mapResults(node, mlirOpResultGetOwner(listOrTupleValue)); mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
return; // Early return, since `importIValue` already added op to block. return; // Early return, since `importIValue` already added op to block.
} else { } else {
@ -259,19 +264,20 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
mapResults(node, operation); mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = { std::vector<MlirType> terminatorOperandTypes = {
torchMlirTorchBoolTypeGet(context)}; torchMlirTorchBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(), terminatorOperandTypes.insert(
resultTypes.begin(), resultTypes.end()); terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues, auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd( createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop.condition", loc, appendToBlock, "torch.prim.Loop.condition", loc,
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, adjustStaticInformationForValues(
terminatorOperandTypes, appendToBlock, loc, yieldedValues, terminatorOperandTypes,
/*userAllowsRefinement=*/false)); /*userAllowsRefinement=*/false));
}; };
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0), mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); importBlock(
node->blocks()[0], createTerminator, c10::nullopt, importOptions));
return; return;
} }
@ -286,25 +292,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
MlirBlock appendToBlock) { MlirBlock appendToBlock) {
createMlirOperationAtEnd( createMlirOperationAtEnd(
appendToBlock, "torch.prim.If.yield", loc, appendToBlock, "torch.prim.If.yield", loc,
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues, adjustStaticInformationForValues(
resultTypes, appendToBlock, loc, yieldedValues, resultTypes,
/*userAllowsRefinement=*/false)); /*userAllowsRefinement=*/false));
}; };
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0), mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); importBlock(
node->blocks()[0], createTerminator, c10::nullopt, importOptions));
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 1), mlirOperationGetRegion(operation, 1),
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions)); importBlock(
node->blocks()[1], createTerminator, c10::nullopt, importOptions));
return; return;
} }
if (kind == c10::prim::CallMethod) { if (kind == c10::prim::CallMethod) {
auto classType = node->input(0)->type()->cast<c10::ClassType>(); auto classType = node->input(0)->type()->cast<c10::ClassType>();
auto methodName = node->s(c10::attr::name); auto methodName = node->s(c10::attr::name);
torch::jit::Function *function = classType->findMethod(methodName); torch::jit::Function* function = classType->findMethod(methodName);
MlirType calleeType = MlirType calleeType = getFunctionTypeFromSchema(
getFunctionTypeFromSchema(context, function->getSchema(), importOptions); context, function->getSchema(), importOptions);
std::vector<MlirType> expectedTypes; std::vector<MlirType> expectedTypes;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) { for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i)); expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
@ -315,17 +323,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
adjustStaticInformationForValues( adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs()), appendToBlock, loc, lookupMappedValues(node->inputs()),
expectedTypes, /*userAllowsRefinement=*/false), expectedTypes, /*userAllowsRefinement=*/false),
toMlirNamedAttribute("name", toMlirNamedAttribute(
importAttribute(loc, node, c10::attr::name))); "name", importAttribute(loc, node, c10::attr::name)));
mapResults(node, operation); mapResults(node, operation);
return; return;
} }
if (kind == c10::prim::CallFunction) { if (kind == c10::prim::CallFunction) {
auto functionType = node->input(0)->type()->cast<c10::FunctionType>(); auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
torch::jit::Block *calleeEntryBlock = torch::jit::Block* calleeEntryBlock =
torch::jit::toGraphFunction(*functionType->function()).graph()->block(); torch::jit::toGraphFunction(*functionType->function()).graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value* v) {
return getMlirTypeFromTorchType(loc, v->type(), importOptions); return getMlirTypeFromTorchType(loc, v->type(), importOptions);
}); });
std::string functionName = node->input(0)->node()->s(c10::attr::name); std::string functionName = node->input(0)->node()->s(c10::attr::name);
@ -340,9 +348,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
// promoted result dtype for a PyTorch computation. Here we turn the call to // promoted result dtype for a PyTorch computation. Here we turn the call to
// this function to the torch dialect equivalent op `torch.promote_dtypes`. // this function to the torch dialect equivalent op `torch.promote_dtypes`.
if (functionName == "__torch_mlir_internal_promote_dtypes") { if (functionName == "__torch_mlir_internal_promote_dtypes") {
operation = operation = createMlirOperationAtEnd(
createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc, appendToBlock, "torch.promote_dtypes", loc, resultTypes,
resultTypes, adjustedFuncArgs); adjustedFuncArgs);
} else { } else {
operation = createMlirOperationAtEnd( operation = createMlirOperationAtEnd(
appendToBlock, "func.call_indirect", loc, resultTypes, appendToBlock, "func.call_indirect", loc, resultTypes,
@ -362,22 +370,22 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
} }
// Imports a TorchScript block: creates the MLIR block (with argument types
// taken from `blockArgTypes` when provided), imports each node in order, then
// lets `createTerminator` append the terminator using the mapped values of
// the block's return node.
MlirBlock NodeImporter::importBlock(
    Block *jitBlock, CreateTerminatorFn createTerminator,
    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
    const ImportOptions &importOptions) {
  MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
  for (Node *node : jitBlock->nodes()) {
    importNode(node, block, importOptions);
  }
  Node *returnNode = jitBlock->return_node();
  createTerminator(lookupMappedValues(returnNode->inputs()), block);
  return block;
}
MlirBlock NodeImporter::createBlockFor( MlirBlock NodeImporter::createBlockFor(
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions) { const ImportOptions& importOptions) {
Node *paramNode = jitBlock->param_node(); Node* paramNode = jitBlock->param_node();
MlirLocation loc = getMlirLocationFromNode(context, paramNode); MlirLocation loc = getMlirLocationFromNode(context, paramNode);
std::vector<MlirType> paramNodeTypes = std::vector<MlirType> paramNodeTypes =
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions); getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
@ -386,11 +394,11 @@ MlirBlock NodeImporter::createBlockFor(
else else
assert(blockArgTypes->size() == paramNodeTypes.size()); assert(blockArgTypes->size() == paramNodeTypes.size());
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc); std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
MlirBlock block = MlirBlock block = mlirBlockCreate(
mlirBlockCreate(blockArgTypes.value().size(), blockArgTypes.value().size(), blockArgTypes.value().data(),
blockArgTypes.value().data(), blockArgLocs.data()); blockArgLocs.data());
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) { for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
Value *jitValue = paramNode->outputs()[i]; Value* jitValue = paramNode->outputs()[i];
MlirValue value = mlirBlockGetArgument(block, i); MlirValue value = mlirBlockGetArgument(block, i);
MlirValue adjusted = adjustStaticInformationForValues( MlirValue adjusted = adjustStaticInformationForValues(
block, loc, {value}, {paramNodeTypes[i]}, block, loc, {value}, {paramNodeTypes[i]},
@ -400,39 +408,40 @@ MlirBlock NodeImporter::createBlockFor(
return block; return block;
} }
// Records the MLIR value corresponding to a TorchScript value. Each jitValue
// may be mapped at most once; a duplicate mapping is a programmer error.
void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
  auto it = valueMap.find(jitValue);
  (void)it; // Only used by the assert; silence unused warning in NDEBUG.
  assert(it == valueMap.end() && "jitValue has already been mapped");
  valueMap[jitValue] = value;
}
void NodeImporter::mapResults(Node *node, MlirOperation operation) { void NodeImporter::mapResults(Node* node, MlirOperation operation) {
assert(node->outputs().size() == assert(
(size_t)mlirOperationGetNumResults(operation)); node->outputs().size() == (size_t)mlirOperationGetNumResults(operation));
for (int i = 0, e = node->outputs().size(); i < e; i++) { for (int i = 0, e = node->outputs().size(); i < e; i++) {
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i)); mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
} }
} }
MlirValue NodeImporter::lookupMappedValue(Value *jitValue) { MlirValue NodeImporter::lookupMappedValue(Value* jitValue) {
auto it = valueMap.find(jitValue); auto it = valueMap.find(jitValue);
assert(it != valueMap.end() && assert(
"trying to get mapping for jitValue that is not mapped yet!"); it != valueMap.end() &&
"trying to get mapping for jitValue that is not mapped yet!");
return it->second; return it->second;
} }
std::vector<MlirValue> std::vector<MlirValue>
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) { NodeImporter::lookupMappedValues(c10::ArrayRef<Value*> values) {
std::vector<MlirValue> ret; std::vector<MlirValue> ret;
for (Value *value : values) { for (Value* value : values) {
ret.push_back(lookupMappedValue(value)); ret.push_back(lookupMappedValue(value));
} }
return ret; return ret;
} }
MlirBlock MlirBlock torch_mlir::importBlock(
torch_mlir::importBlock(MlirContext context, Block *jitBlock, MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator,
CreateTerminatorFn createTerminator, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes, const ImportOptions& importOptions) {
const ImportOptions &importOptions) {
NodeImporter importer(context); NodeImporter importer(context);
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions); return importer.importBlock(
jitBlock, createTerminator, blockArgTypes, importOptions);
} }

View File

@ -37,10 +37,10 @@ using CreateTerminatorFn =
/// adjust the types to the block argument types. /// adjust the types to the block argument types.
/// TODO: Formalize what type conversions are allowed here. /// TODO: Formalize what type conversions are allowed here.
MlirBlock importBlock( MlirBlock importBlock(
MlirContext context, torch::jit::Block *jitBlock, MlirContext context, torch::jit::Block* jitBlock,
CreateTerminatorFn createTerminator, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions &importOptions = {}); const ImportOptions& importOptions = {});
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -26,8 +26,8 @@
using namespace torch_mlir; using namespace torch_mlir;
static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, static MlirType getMlirTypeForTorchScalarTypeRaw(
c10::ScalarType scalarType) { MlirContext context, c10::ScalarType scalarType) {
using c10::ScalarType; using c10::ScalarType;
switch (scalarType) { switch (scalarType) {
case ScalarType::Byte: case ScalarType::Byte:
@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
} }
} }
MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc, MlirType torch_mlir::getMlirTypeForTorchScalarType(
c10::ScalarType scalarType) { MlirLocation loc, c10::ScalarType scalarType) {
auto type = auto type =
getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType); getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
// There is no generic way to import custom classes (or their types), so we // There is no generic way to import custom classes (or their types), so we
// have to name match them here (and the relevant code in the ivalue // have to name match them here (and the relevant code in the ivalue
// importer) and create special IR constructs for them. // importer) and create special IR constructs for them.
static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, static MlirType mapCustomClassType(
const c10::ClassTypePtr &classType) { MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) {
// If the type is unnamed, it cannot be a custom class. // If the type is unnamed, it cannot be a custom class.
if (!classType->name().has_value()) { if (!classType->name().has_value()) {
return {nullptr}; return {nullptr};
@ -126,10 +126,9 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
MlirType MlirType torch_mlir::getMlirTypeFromTorchType(
torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, MlirLocation loc, const c10::TypePtr& torchType,
const c10::TypePtr &torchType, const ImportOptions& importOptions) {
const ImportOptions &importOptions) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
using c10::TypeKind; using c10::TypeKind;
auto kind = torchType->kind(); auto kind = torchType->kind();
@ -141,10 +140,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
: torchMlirTorchNonValueTensorTypeGet; : torchMlirTorchNonValueTensorTypeGet;
if (importOptions.ignoreExistingTensorShapesAndDtypes) { if (importOptions.ignoreExistingTensorShapesAndDtypes) {
return getMlirTensorType(context, return getMlirTensorType(
/*numSizes=*/-1, context,
/*optionalSizes=*/nullptr, /*numSizes=*/-1,
/*optionalDtype=*/{nullptr}); /*optionalSizes=*/nullptr,
/*optionalDtype=*/{nullptr});
} }
// Element type. // Element type.
@ -156,17 +156,18 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
return {nullptr}; return {nullptr};
} }
// Sizes. // Sizes.
auto &sizes = tensorType->symbolic_sizes(); auto& sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) { if (!sizes.rank()) {
// Unranked. // Unranked.
return getMlirTensorType(context, return getMlirTensorType(
/*numSizes=*/-1, context,
/*optionalSizes=*/nullptr, /*numSizes=*/-1,
/*optionalDtype=*/ /*optionalSizes=*/nullptr,
elementType); /*optionalDtype=*/
elementType);
} }
// Ranked with possibly dynamic dims. // Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes(); auto& symbolicShape = tensorType->symbolic_sizes();
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.resize(*sizes.rank()); dims.resize(*sizes.rank());
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
@ -179,11 +180,12 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
// the C API constructor, when we want the "we know we have 0 sizes" // the C API constructor, when we want the "we know we have 0 sizes"
// case. So use a dummy data pointer. // case. So use a dummy data pointer.
int64_t dummy; int64_t dummy;
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data(); int64_t* dimsData = dims.size() == 0 ? &dummy : dims.data();
return getMlirTensorType(context, dims.size(), return getMlirTensorType(
/*optionalSizes=*/dimsData, context, dims.size(),
/*optionalDtype=*/ /*optionalSizes=*/dimsData,
elementType); /*optionalDtype=*/
elementType);
} }
case TypeKind::IntType: { case TypeKind::IntType: {
return torchMlirTorchIntTypeGet(context); return torchMlirTorchIntTypeGet(context);
@ -207,22 +209,22 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
} }
case TypeKind::TupleType: { case TypeKind::TupleType: {
std::vector<MlirType> containedTypes; std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type : for (const c10::TypePtr& type :
torchType->cast<c10::TupleType>()->containedTypes()) { torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back( containedTypes.push_back(
getMlirTypeFromTorchType(loc, type, importOptions)); getMlirTypeFromTorchType(loc, type, importOptions));
} }
return torchMlirTorchTupleTypeGet(context, containedTypes.size(), return torchMlirTorchTupleTypeGet(
containedTypes.data()); context, containedTypes.size(), containedTypes.data());
} }
case TypeKind::UnionType: { case TypeKind::UnionType: {
std::vector<MlirType> containedTypes; std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type : for (const c10::TypePtr& type :
torchType->cast<c10::UnionType>()->containedTypes()) { torchType->cast<c10::UnionType>()->containedTypes()) {
containedTypes.push_back(getMlirTypeFromTorchType(loc, type)); containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
} }
return torchMlirTorchUnionTypeGet(context, containedTypes.size(), return torchMlirTorchUnionTypeGet(
containedTypes.data()); context, containedTypes.size(), containedTypes.data());
} }
case TypeKind::ListType: { case TypeKind::ListType: {
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
@ -242,7 +244,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
return torchMlirTorchAnyTypeGet(context); return torchMlirTorchAnyTypeGet(context);
} }
case TypeKind::ClassType: { case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>(); const c10::ClassTypePtr& classType = torchType->cast<c10::ClassType>();
MlirType customClassType = mapCustomClassType(context, loc, classType); MlirType customClassType = mapCustomClassType(context, loc, classType);
if (!mlirTypeIsNull(customClassType)) { if (!mlirTypeIsNull(customClassType)) {
return customClassType; return customClassType;
@ -266,12 +268,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
} }
} }
MlirType MlirType torch_mlir::getFunctionTypeFromSchema(
torch_mlir::getFunctionTypeFromSchema(MlirContext context, MlirContext context, const c10::FunctionSchema& schema,
const c10::FunctionSchema &schema, const ImportOptions& importOptions) {
const ImportOptions &importOptions) {
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
auto mapType = [&](const c10::TypePtr &torchType) { auto mapType = [&](const c10::TypePtr& torchType) {
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions); MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
if (mlirTypeIsNull(type)) { if (mlirTypeIsNull(type)) {
std::stringstream msg; std::stringstream msg;
@ -283,17 +284,20 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
}; };
std::vector<MlirType> inputTypes = std::vector<MlirType> inputTypes =
c10::fmap(schema.arguments(), c10::fmap(schema.arguments(), [&](const c10::Argument& arg) {
[&](const c10::Argument &arg) { return mapType(arg.type()); }); return mapType(arg.type());
});
std::vector<MlirType> outputTypes = std::vector<MlirType> outputTypes =
c10::fmap(schema.returns(), c10::fmap(schema.returns(), [&](const c10::Argument& arg) {
[&](const c10::Argument &arg) { return mapType(arg.type()); }); return mapType(arg.type());
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), });
outputTypes.size(), outputTypes.data()); return mlirFunctionTypeGet(
context, inputTypes.size(), inputTypes.data(), outputTypes.size(),
outputTypes.data());
} }
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
MlirLocation loc) { at::Tensor tensor, MlirLocation loc) {
using at::ScalarType; using at::ScalarType;
auto throwUnsupportedTensorError = [&]() { auto throwUnsupportedTensorError = [&]() {
@ -308,8 +312,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
// The flat number of bytes throws an exception for tensors that are not // The flat number of bytes throws an exception for tensors that are not
// dense and accessible as such. // dense and accessible as such.
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor, at::checkLayout(
c10::Layout::Strided); at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided);
// Construct the ShapedType. // Construct the ShapedType.
@ -334,47 +338,47 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
switch (tensor.scalar_type()) { switch (tensor.scalar_type()) {
case ScalarType::Int: case ScalarType::Int:
return mlirDenseElementsAttrInt32Get( return mlirDenseElementsAttrInt32Get(
shapedType, numElements, static_cast<const int32_t *>(tensorData)); shapedType, numElements, static_cast<const int32_t*>(tensorData));
break; break;
case ScalarType::Long: case ScalarType::Long:
return mlirDenseElementsAttrInt64Get( return mlirDenseElementsAttrInt64Get(
shapedType, numElements, static_cast<const int64_t *>(tensorData)); shapedType, numElements, static_cast<const int64_t*>(tensorData));
break; break;
case ScalarType::Float: case ScalarType::Float:
return mlirDenseElementsAttrFloatGet( return mlirDenseElementsAttrFloatGet(
shapedType, numElements, static_cast<const float *>(tensorData)); shapedType, numElements, static_cast<const float*>(tensorData));
break; break;
case ScalarType::Double: case ScalarType::Double:
return mlirDenseElementsAttrDoubleGet( return mlirDenseElementsAttrDoubleGet(
shapedType, numElements, static_cast<const double *>(tensorData)); shapedType, numElements, static_cast<const double*>(tensorData));
break; break;
case ScalarType::Bool: { case ScalarType::Bool: {
// TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed // TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed
// upstream to take in a `const bool *` rather than a `const int *` to avoid // upstream to take in a `const bool *` rather than a `const int *` to avoid
// the unnecessary copying into an array four times as large. // the unnecessary copying into an array four times as large.
const int8_t *elements = static_cast<const int8_t *>(tensorData); const int8_t* elements = static_cast<const int8_t*>(tensorData);
std::vector<int> tensorDataVector(elements, elements + numElements); std::vector<int> tensorDataVector(elements, elements + numElements);
return mlirDenseElementsAttrBoolGet(shapedType, numElements, return mlirDenseElementsAttrBoolGet(
tensorDataVector.data()); shapedType, numElements, tensorDataVector.data());
} break; } break;
case ScalarType::QInt8: case ScalarType::QInt8:
return mlirDenseElementsAttrInt8Get( return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t *>(tensorData)); shapedType, numElements, static_cast<const int8_t*>(tensorData));
case ScalarType::QUInt8: case ScalarType::QUInt8:
return mlirDenseElementsAttrUInt8Get( return mlirDenseElementsAttrUInt8Get(
shapedType, numElements, static_cast<const uint8_t *>(tensorData)); shapedType, numElements, static_cast<const uint8_t*>(tensorData));
case ScalarType::BFloat16: case ScalarType::BFloat16:
return mlirDenseElementsAttrBFloat16Get( return mlirDenseElementsAttrBFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData)); shapedType, numElements, static_cast<const uint16_t*>(tensorData));
case ScalarType::Half: case ScalarType::Half:
return mlirDenseElementsAttrFloat16Get( return mlirDenseElementsAttrFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData)); shapedType, numElements, static_cast<const uint16_t*>(tensorData));
case ScalarType::Byte: case ScalarType::Byte:
return mlirDenseElementsAttrUInt8Get( return mlirDenseElementsAttrUInt8Get(
shapedType, numElements, static_cast<const uint8_t *>(tensorData)); shapedType, numElements, static_cast<const uint8_t*>(tensorData));
case ScalarType::Char: case ScalarType::Char:
return mlirDenseElementsAttrInt8Get( return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t *>(tensorData)); shapedType, numElements, static_cast<const int8_t*>(tensorData));
default: default:
throwUnsupportedTensorError(); throwUnsupportedTensorError();
@ -382,9 +386,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
return {nullptr}; // Unreachable. return {nullptr}; // Unreachable.
} }
MlirAttribute torch_mlir::importAttribute(MlirLocation loc, MlirAttribute torch_mlir::importAttribute(
torch::jit::Node *node, MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) {
c10::Symbol symbol) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
auto kind = node->kindOf(symbol); auto kind = node->kindOf(symbol);
switch (kind) { switch (kind) {
@ -393,8 +396,8 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
// do that. // do that.
return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol)); return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol));
case torch::jit::AttributeKind::f: case torch::jit::AttributeKind::f:
return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context), return mlirFloatAttrDoubleGet(
node->f(symbol)); context, mlirF64TypeGet(context), node->f(symbol));
case torch::jit::AttributeKind::s: case torch::jit::AttributeKind::s:
return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol))); return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
case torch::jit::AttributeKind::t: case torch::jit::AttributeKind::t:
@ -408,23 +411,23 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
} }
} }
MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, MlirLocation torch_mlir::getMlirLocationFromNode(
torch::jit::Node *node) { MlirContext context, torch::jit::Node* node) {
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
if (node->hasAttribute(c10::Symbol::attr("source_files"))) { if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files")); const auto& sourceFiles = node->ss(c10::Symbol::attr("source_files"));
const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers")); const auto& lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
const auto &functions = node->ss(c10::Symbol::attr("functions")); const auto& functions = node->ss(c10::Symbol::attr("functions"));
// Chain a sequence of calls to construct single MlirLocation. // Chain a sequence of calls to construct single MlirLocation.
for (const auto i : c10::irange(sourceFiles.size())) { for (const auto i : c10::irange(sourceFiles.size())) {
MlirLocation newLoc = mlirLocationNameGet( MlirLocation newLoc = mlirLocationNameGet(
context, toMlirStringRef(functions[i]), context, toMlirStringRef(functions[i]),
mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]), mlirLocationFileLineColGet(
lineNumbers[i], context, toMlirStringRef(sourceFiles[i]), lineNumbers[i],
0 /* column is not available */ 0 /* column is not available */
)); ));
loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc)); loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
} }
if (sourceFiles.size() == 1) { if (sourceFiles.size() == 1) {
@ -433,7 +436,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context)); loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
} }
} else if (auto flc = node->sourceRange().file_line_col()) { } else if (auto flc = node->sourceRange().file_line_col()) {
const std::string &file = std::get<0>(*flc); const std::string& file = std::get<0>(*flc);
int line = std::get<1>(*flc); int line = std::get<1>(*flc);
int col = std::get<2>(*flc); int col = std::get<2>(*flc);
loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col); loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
@ -445,7 +448,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
locationName = scopeName; locationName = scopeName;
} }
if (const c10::FunctionSchema *schema = node->maybeSchema()) { if (const c10::FunctionSchema* schema = node->maybeSchema()) {
if (!locationName.empty()) { if (!locationName.empty()) {
locationName += "/"; locationName += "/";
} }
@ -459,10 +462,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
return loc; return loc;
} }
std::vector<MlirType> std::vector<MlirType> torch_mlir::getMlirTypesFromValues(
torch_mlir::getMlirTypesFromValues(MlirLocation loc, MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
c10::ArrayRef<torch::jit::Value *> values, const ImportOptions& importOptions) {
const ImportOptions &importOptions) {
std::vector<MlirType> ret; std::vector<MlirType> ret;
for (auto value : values) { for (auto value : values) {
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions); MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
@ -491,25 +493,24 @@ std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
} }
std::stringstream msg; std::stringstream msg;
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) { MlirStringCallback printToStream = +[](MlirStringRef str, void* userData) {
std::stringstream *stream = static_cast<std::stringstream *>(userData); std::stringstream* stream = static_cast<std::stringstream*>(userData);
stream->write(str.data, str.length); stream->write(str.data, str.length);
}; };
msg << "unhandled: could not adjust static info for type from "; msg << "unhandled: could not adjust static info for type from ";
mlirTypePrint(type, printToStream, static_cast<void *>(&msg)); mlirTypePrint(type, printToStream, static_cast<void*>(&msg));
msg << " to type "; msg << " to type ";
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg)); mlirTypePrint(expectedType, printToStream, static_cast<void*>(&msg));
mlirEmitError(loc, msg.str().c_str()); mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted(); throw mlir_diagnostic_emitted();
} }
return ret; return ret;
} }
MlirOperation MlirOperation torch_mlir::createOperationFromSchema(
torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc, MlirBlock appendToBlock, MlirLocation loc,
const c10::FunctionSchema &schema, const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
c10::ArrayRef<MlirType> resultTypes, c10::ArrayRef<MlirValue> operands) {
c10::ArrayRef<MlirValue> operands) {
MlirContext context = mlirLocationGetContext(loc); MlirContext context = mlirLocationGetContext(loc);
// Munge the name into the appropriate MLIR operation name. // Munge the name into the appropriate MLIR operation name.
@ -519,15 +520,15 @@ torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc,
auto separatorPosition = opNameSuffix.find_first_of("::"); auto separatorPosition = opNameSuffix.find_first_of("::");
assert(separatorPosition != std::string::npos); assert(separatorPosition != std::string::npos);
opNameSuffix.replace(separatorPosition, 2, "."); opNameSuffix.replace(separatorPosition, 2, ".");
const std::string &overloadName = schema.overload_name(); const std::string& overloadName = schema.overload_name();
if (!overloadName.empty()) { if (!overloadName.empty()) {
opNameSuffix = opNameSuffix + "." + overloadName; opNameSuffix = opNameSuffix + "." + overloadName;
} }
std::string opName = "torch." + opNameSuffix; std::string opName = "torch." + opNameSuffix;
// If we have a registered op, use it! // If we have a registered op, use it!
if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) { if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) {
return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes, return createMlirOperationAtEnd(
operands); appendToBlock, opName, loc, resultTypes, operands);
} }
// Oops, no registered op -- create an opaque wrapper so that import can // Oops, no registered op -- create an opaque wrapper so that import can
// still succeed. This helps a common use case of filling out registered ops // still succeed. This helps a common use case of filling out registered ops

View File

@ -25,7 +25,7 @@ namespace torch_mlir {
/// Thrown on failure when details are in MLIR emitted diagnostics. /// Thrown on failure when details are in MLIR emitted diagnostics.
class mlir_diagnostic_emitted : public std::runtime_error { class mlir_diagnostic_emitted : public std::runtime_error {
public: public:
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {} mlir_diagnostic_emitted(const char* what) : std::runtime_error(what) {}
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {} mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
}; };
@ -38,37 +38,36 @@ public:
/// for Python code). /// for Python code).
/// ///
/// Returns a null type on failure and emits a diagnostic. /// Returns a null type on failure and emits a diagnostic.
MlirType getMlirTypeForTorchScalarType(MlirLocation loc, MlirType
c10::ScalarType scalarType); getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type /// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic. /// on failure and emits a diagnostic.
MlirType getMlirTypeFromTorchType(MlirLocation loc, MlirType getMlirTypeFromTorchType(
const c10::TypePtr &torchType, MlirLocation loc, const c10::TypePtr& torchType,
const ImportOptions &importOptions = {}); const ImportOptions& importOptions = {});
/// Creates a FunctionType suitable for expressing the signature of `schema`. /// Creates a FunctionType suitable for expressing the signature of `schema`.
/// ///
/// This can differ from the type inferred from the block of a /// This can differ from the type inferred from the block of a
/// torch::jit::Function due to derefinement and refinement of tensor types. /// torch::jit::Function due to derefinement and refinement of tensor types.
MlirType getFunctionTypeFromSchema(MlirContext context, MlirType getFunctionTypeFromSchema(
const c10::FunctionSchema &schema, MlirContext context, const c10::FunctionSchema& schema,
const ImportOptions &importOptions = {}); const ImportOptions& importOptions = {});
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`. /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor, MlirAttribute
MlirLocation loc); convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc);
MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node, MlirAttribute
c10::Symbol symbol); importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol);
MlirLocation getMlirLocationFromNode(MlirContext context, MlirLocation
torch::jit::Node *node); getMlirLocationFromNode(MlirContext context, torch::jit::Node* node);
std::vector<MlirType> std::vector<MlirType> getMlirTypesFromValues(
getMlirTypesFromValues(MlirLocation loc, MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
c10::ArrayRef<torch::jit::Value *> values, const ImportOptions& importOptions = {});
const ImportOptions &importOptions = {});
std::vector<MlirValue> adjustStaticInformationForValues( std::vector<MlirValue> adjustStaticInformationForValues(
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values, MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
@ -79,11 +78,10 @@ std::vector<MlirValue> adjustStaticInformationForValues(
/// ///
/// The primary difficulty here is doing the appropriate name munging and /// The primary difficulty here is doing the appropriate name munging and
/// checking if the have a registered op. /// checking if the have a registered op.
MlirOperation createOperationFromSchema(MlirBlock appendToBlock, MlirOperation createOperationFromSchema(
MlirLocation loc, MlirBlock appendToBlock, MlirLocation loc,
const c10::FunctionSchema &schema, const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
c10::ArrayRef<MlirType> resultTypes, c10::ArrayRef<MlirValue> operands);
c10::ArrayRef<MlirValue> operands);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -14,12 +14,12 @@
#include <torch/csrc/lazy/core/lazy_graph_executor.h> #include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/shape.h> #include <torch/csrc/lazy/core/shape.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h> #include <base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h> #include <base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h> #include <base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h> #include <base_lazy_backend/utils/debug.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h> #include <base_lazy_backend/utils/exception.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h> #include <base_lazy_backend/utils/string_utils.h>
#include "backend_impl.h" #include "backend_impl.h"

View File

@ -11,10 +11,10 @@
#include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/core/config.h"
#include "torch/csrc/lazy/backend/backend_interface.h" #include "torch/csrc/lazy/backend/backend_interface.h"
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h> #include <base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h> #include <base_lazy_backend/utils/string_utils.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h> #include <base_lazy_backend/utils/sys_utils.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h> #include <base_lazy_backend/utils/tensor_utils.h>
#include <exception> #include <exception>
#include <iostream> #include <iostream>

View File

@ -2,8 +2,6 @@
# Subdirectories # Subdirectories
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
add_subdirectory(csrc)
## Declare the sources of the Python module. ## Declare the sources of the Python module.
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter