mirror of https://github.com/llvm/torch-mlir
Get LTC building.
parent
fabb4d6e5d
commit
606dc45896
|
@ -26,7 +26,7 @@ __pycache__
|
||||||
bazel-*
|
bazel-*
|
||||||
|
|
||||||
# Autogenerated files
|
# Autogenerated files
|
||||||
/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated
|
/projects/ltc/csrc/base_lazy_backend/generated
|
||||||
|
|
||||||
#Docker builds
|
#Docker builds
|
||||||
build_oot/
|
build_oot/
|
||||||
|
|
|
@ -29,7 +29,6 @@ if not TORCH_INCLUDE_DIR.is_dir():
|
||||||
TORCH_INCLUDE_DIR = TORCH_DIR
|
TORCH_INCLUDE_DIR = TORCH_DIR
|
||||||
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
||||||
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
||||||
TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1"
|
|
||||||
|
|
||||||
def reindent(text, prefix=""):
|
def reindent(text, prefix=""):
|
||||||
return indent(dedent(text), prefix)
|
return indent(dedent(text), prefix)
|
||||||
|
@ -114,12 +113,12 @@ class GenTorchMlirLTC:
|
||||||
self.binary_dir = Path(binary_dir)
|
self.binary_dir = Path(binary_dir)
|
||||||
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
||||||
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
||||||
self.backend_path = TORCH_MLIR_PT1_DIR.joinpath(
|
self.backend_path = TORCH_MLIR_DIR.joinpath(
|
||||||
"python", "torch_mlir", "csrc", "base_lazy_backend"
|
"projects", "ltc", "csrc", "base_lazy_backend"
|
||||||
)
|
)
|
||||||
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
||||||
self.generated_path = self.binary_dir.joinpath(
|
self.generated_path = self.binary_dir.joinpath(
|
||||||
"projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated"
|
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
|
||||||
)
|
)
|
||||||
self.generated_path.mkdir(parents=True, exist_ok=True)
|
self.generated_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -415,7 +414,7 @@ class GenTorchMlirLTC:
|
||||||
// for ops that dont have a corresponding structured kernel or shape definition
|
// for ops that dont have a corresponding structured kernel or shape definition
|
||||||
|
|
||||||
#include "shape_inference.h"
|
#include "shape_inference.h"
|
||||||
#include "torch_mlir/csrc/base_lazy_backend/utils/exception.h"
|
#include "base_lazy_backend/utils/exception.h"
|
||||||
namespace torch {{
|
namespace torch {{
|
||||||
namespace lazy {{
|
namespace lazy {{
|
||||||
{}
|
{}
|
||||||
|
@ -467,7 +466,7 @@ class GenTorchMlirLTC:
|
||||||
node_base="torch::lazy::TorchMlirNode",
|
node_base="torch::lazy::TorchMlirNode",
|
||||||
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
||||||
tensor_class=self.tensor_class,
|
tensor_class=self.tensor_class,
|
||||||
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
|
tensor_class_hdr="base_lazy_backend/tensor.h",
|
||||||
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
||||||
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
||||||
lazy_ir_generator=GenMlirLazyIr,
|
lazy_ir_generator=GenMlirLazyIr,
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
|
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
|
||||||
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
|
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
|
||||||
|
|
||||||
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
||||||
Implementations based on this abstract class will be able to specify their own compile and execution workflows.
|
Implementations based on this abstract class will be able to specify their own compile and execution workflows.
|
||||||
Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend).
|
Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend).
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ View examples [here](ltc_examples.md).
|
||||||
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
|
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
|
||||||
excluding those explicitly blacklisted in the YAML file
|
excluding those explicitly blacklisted in the YAML file
|
||||||
|
|
||||||
### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated))
|
### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated))
|
||||||
Generated files are created in this directory, which is ignored by version control.
|
Generated files are created in this directory, which is ignored by version control.
|
||||||
|
|
||||||
- `LazyIr.h`
|
- `LazyIr.h`
|
||||||
|
@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr
|
||||||
- `shape_inference.{cpp,h}`
|
- `shape_inference.{cpp,h}`
|
||||||
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
|
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
|
||||||
|
|
||||||
### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend))
|
### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend))
|
||||||
|
|
||||||
- `backend_impl.{cpp,h}`
|
- `backend_impl.{cpp,h}`
|
||||||
- Base LTC backend to setup Torch-MLIR lowering context
|
- Base LTC backend to setup Torch-MLIR lowering context
|
||||||
|
|
|
@ -56,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED
|
||||||
utils/tensor_utils.cpp
|
utils/tensor_utils.cpp
|
||||||
)
|
)
|
||||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||||
|
# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/...").
|
||||||
|
# Add both the source and generated include directories.
|
||||||
|
target_include_directories(torch_mlir_ltc_backend PUBLIC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/..
|
||||||
|
)
|
||||||
|
|
||||||
add_dependencies(torch_mlir_ltc_backend
|
add_dependencies(torch_mlir_ltc_backend
|
||||||
TorchMLIRJITIRImporter
|
TorchMLIRJITIRImporter
|
||||||
|
@ -88,13 +94,13 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
@ -105,7 +111,7 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
@ -116,5 +122,5 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir-c/Pass.h"
|
#include "mlir-c/Pass.h"
|
||||||
|
|
||||||
#include "../../jit_ir_importer/csrc/function_importer.h"
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
|
#include "jit_ir_importer/function_importer.h"
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
#include "mlir_node.h"
|
#include "mlir_node.h"
|
||||||
#include "utils/debug.h"
|
#include "utils/debug.h"
|
||||||
|
|
|
@ -100,6 +100,7 @@ endif()
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||||
add_subdirectory(torch_mlir/jit_ir_importer)
|
add_subdirectory(torch_mlir/jit_ir_importer)
|
||||||
|
add_subdirectory(torch_mlir/csrc/jit_ir_importer)
|
||||||
add_subdirectory(torch_mlir_e2e_test)
|
add_subdirectory(torch_mlir_e2e_test)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,10 @@ target_link_libraries(TorchMLIRJITIRImporter
|
||||||
TorchMLIRAggregateCAPI
|
TorchMLIRAggregateCAPI
|
||||||
${TORCH_LIBRARIES}
|
${TORCH_LIBRARIES}
|
||||||
)
|
)
|
||||||
|
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
|
||||||
|
target_include_directories(TorchMLIRJITIRImporter PUBLIC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
|
)
|
||||||
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
||||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||||
OUTPUT_NAME lib_jit_ir_importer
|
OUTPUT_NAME lib_jit_ir_importer
|
|
@ -18,8 +18,8 @@ using namespace torch_mlir;
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Prefix every line of `s` with `linePrefix`.
|
// Prefix every line of `s` with `linePrefix`.
|
||||||
static std::string indentString(const std::string &linePrefix,
|
static std::string
|
||||||
const std::string &s) {
|
indentString(const std::string& linePrefix, const std::string& s) {
|
||||||
std::stringstream is(s);
|
std::stringstream is(s);
|
||||||
std::stringstream os;
|
std::stringstream os;
|
||||||
std::string line;
|
std::string line;
|
||||||
|
@ -39,26 +39,28 @@ ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType)
|
||||||
methodAnnotations.resize(classType->methods().size());
|
methodAnnotations.resize(classType->methods().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AttributeAnnotation> &ClassAnnotation::getAttributeAnnotations() {
|
std::vector<AttributeAnnotation>& ClassAnnotation::getAttributeAnnotations() {
|
||||||
// Halfhearted attempt to ensure consistency if the class type has
|
// Halfhearted attempt to ensure consistency if the class type has
|
||||||
// been mutated.
|
// been mutated.
|
||||||
//
|
//
|
||||||
// We can't easily guard against attributes being removed and
|
// We can't easily guard against attributes being removed and
|
||||||
// then other attributes being added, or types changed, etc. without
|
// then other attributes being added, or types changed, etc. without
|
||||||
// effectively mirroring the entire ClassType.
|
// effectively mirroring the entire ClassType.
|
||||||
assert(attributeAnnotations.size() == classType->getAttributes().size() &&
|
assert(
|
||||||
"annotations out of sync. class has been mutated");
|
attributeAnnotations.size() == classType->getAttributes().size() &&
|
||||||
|
"annotations out of sync. class has been mutated");
|
||||||
|
|
||||||
return attributeAnnotations;
|
return attributeAnnotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() {
|
std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() {
|
||||||
// Halfhearted attempt to ensure consistency if the class type has
|
// Halfhearted attempt to ensure consistency if the class type has
|
||||||
// been mutated.
|
// been mutated.
|
||||||
//
|
//
|
||||||
// We can't easily guard against methods being removed, added, or changed.
|
// We can't easily guard against methods being removed, added, or changed.
|
||||||
assert(methodAnnotations.size() == classType->methods().size() &&
|
assert(
|
||||||
"annotations out of sync. class has been mutated");
|
methodAnnotations.size() == classType->methods().size() &&
|
||||||
|
"annotations out of sync. class has been mutated");
|
||||||
|
|
||||||
return methodAnnotations;
|
return methodAnnotations;
|
||||||
}
|
}
|
||||||
|
@ -67,17 +69,17 @@ std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() {
|
||||||
// ClassAnnotator
|
// ClassAnnotator
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void exportNoneRecurse(ClassAnnotator &classAnnotator,
|
static void
|
||||||
c10::ClassType *classType) {
|
exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) {
|
||||||
ClassAnnotation &classAnnotation =
|
ClassAnnotation& classAnnotation =
|
||||||
classAnnotator.getOrCreateClassAnnotation(classType);
|
classAnnotator.getOrCreateClassAnnotation(classType);
|
||||||
for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) {
|
for (auto& attributeAnnotation : classAnnotation.getAttributeAnnotations()) {
|
||||||
attributeAnnotation.isExported = false;
|
attributeAnnotation.isExported = false;
|
||||||
}
|
}
|
||||||
for (auto &methodAnnotation : classAnnotation.getMethodAnnotations()) {
|
for (auto& methodAnnotation : classAnnotation.getMethodAnnotations()) {
|
||||||
methodAnnotation.isExported = false;
|
methodAnnotation.isExported = false;
|
||||||
}
|
}
|
||||||
for (auto &classAttribute : classType->getAttributes()) {
|
for (auto& classAttribute : classType->getAttributes()) {
|
||||||
if (auto childClassType =
|
if (auto childClassType =
|
||||||
classAttribute.getType()->cast<c10::ClassType>()) {
|
classAttribute.getType()->cast<c10::ClassType>()) {
|
||||||
exportNoneRecurse(classAnnotator, childClassType.get());
|
exportNoneRecurse(classAnnotator, childClassType.get());
|
||||||
|
@ -85,20 +87,20 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClassAnnotator::exportNone(c10::ClassType &rootClassType) {
|
void ClassAnnotator::exportNone(c10::ClassType& rootClassType) {
|
||||||
exportNoneRecurse(*this, &rootClassType);
|
exportNoneRecurse(*this, &rootClassType);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
void ClassAnnotator::exportPath(
|
||||||
std::vector<std::string> exportedPath) {
|
c10::ClassType& rootClassType, std::vector<std::string> exportedPath) {
|
||||||
if (exportedPath.size() == 0) {
|
if (exportedPath.size() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Empty exported path. Can only export a property of a class.");
|
"Empty exported path. Can only export a property of a class.");
|
||||||
}
|
}
|
||||||
c10::ClassType *classType =
|
c10::ClassType* classType = getClassAtPath(
|
||||||
getClassAtPath(&rootClassType, c10::ArrayRef<std::string>(exportedPath)
|
&rootClassType, c10::ArrayRef<std::string>(exportedPath)
|
||||||
.slice(0, exportedPath.size() - 1)
|
.slice(0, exportedPath.size() - 1)
|
||||||
.vec());
|
.vec());
|
||||||
|
|
||||||
if (!classType->findAttribute(exportedPath.back()) &&
|
if (!classType->findAttribute(exportedPath.back()) &&
|
||||||
!classType->findMethod(exportedPath.back())) {
|
!classType->findMethod(exportedPath.back())) {
|
||||||
|
@ -108,10 +110,10 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
||||||
<< exportedPath.back() << "'";
|
<< exportedPath.back() << "'";
|
||||||
throw std::invalid_argument(ss.str());
|
throw std::invalid_argument(ss.str());
|
||||||
}
|
}
|
||||||
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
|
ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType);
|
||||||
std::vector<AttributeAnnotation> &attributeAnnotations =
|
std::vector<AttributeAnnotation>& attributeAnnotations =
|
||||||
classAnnotation.getAttributeAnnotations();
|
classAnnotation.getAttributeAnnotations();
|
||||||
const std::vector<c10::ClassAttribute> &classAttributes =
|
const std::vector<c10::ClassAttribute>& classAttributes =
|
||||||
classType->getAttributes();
|
classType->getAttributes();
|
||||||
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
||||||
if (classAttributes[i].getName() == exportedPath.back()) {
|
if (classAttributes[i].getName() == exportedPath.back()) {
|
||||||
|
@ -119,9 +121,9 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MethodAnnotation> &methodAnnotations =
|
std::vector<MethodAnnotation>& methodAnnotations =
|
||||||
classAnnotation.getMethodAnnotations();
|
classAnnotation.getMethodAnnotations();
|
||||||
const std::vector<torch::jit::Function *> &methods = classType->methods();
|
const std::vector<torch::jit::Function*>& methods = classType->methods();
|
||||||
for (int i = 0, e = methods.size(); i != e; i++) {
|
for (int i = 0, e = methods.size(); i != e; i++) {
|
||||||
if (methods[i]->name() == exportedPath.back()) {
|
if (methods[i]->name() == exportedPath.back()) {
|
||||||
methodAnnotations[i].isExported = true;
|
methodAnnotations[i].isExported = true;
|
||||||
|
@ -129,12 +131,12 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const ClassAnnotationMap &ClassAnnotator::getAnnotationMap() {
|
const ClassAnnotationMap& ClassAnnotator::getAnnotationMap() {
|
||||||
return classAnnotations;
|
return classAnnotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
ClassAnnotation &
|
ClassAnnotation&
|
||||||
ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) {
|
||||||
auto className = classType->name()->qualifiedName();
|
auto className = classType->name()->qualifiedName();
|
||||||
auto it = classAnnotations.find(className);
|
auto it = classAnnotations.find(className);
|
||||||
if (it == classAnnotations.end()) {
|
if (it == classAnnotations.end()) {
|
||||||
|
@ -149,39 +151,39 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
|
||||||
return *it->second;
|
return *it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
|
static void fillArgAnnotations(
|
||||||
std::vector<ArgAnnotation> argAnnotations,
|
MethodAnnotation& methodAnnotation,
|
||||||
torch::jit::Function *function) {
|
std::vector<ArgAnnotation> argAnnotations, torch::jit::Function* function) {
|
||||||
if (argAnnotations.size() != function->num_inputs()) {
|
if (argAnnotations.size() != function->num_inputs()) {
|
||||||
throw std::invalid_argument("Arg annotations should have one entry per "
|
throw std::invalid_argument("Arg annotations should have one entry per "
|
||||||
"function parameter (including self).");
|
"function parameter (including self).");
|
||||||
}
|
}
|
||||||
if (!methodAnnotation.argAnnotations.has_value()) {
|
if (!methodAnnotation.argAnnotations.has_value()) {
|
||||||
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
|
methodAnnotation.argAnnotations.emplace(
|
||||||
ArgAnnotation{});
|
function->num_inputs(), ArgAnnotation{});
|
||||||
}
|
}
|
||||||
|
|
||||||
methodAnnotation.argAnnotations = argAnnotations;
|
methodAnnotation.argAnnotations = argAnnotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
void ClassAnnotator::annotateArgs(
|
||||||
std::vector<std::string> path,
|
c10::ClassType& rootClassType, std::vector<std::string> path,
|
||||||
std::vector<ArgAnnotation> argAnnotations) {
|
std::vector<ArgAnnotation> argAnnotations) {
|
||||||
if (path.size() == 0) {
|
if (path.size() == 0) {
|
||||||
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
||||||
"shapes/dtypes of a method of a class.");
|
"shapes/dtypes of a method of a class.");
|
||||||
}
|
}
|
||||||
c10::ClassType *classType = getClassAtPath(
|
c10::ClassType* classType = getClassAtPath(
|
||||||
&rootClassType,
|
&rootClassType,
|
||||||
c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec());
|
c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec());
|
||||||
|
|
||||||
// Throw error if no method on the class of the specified name.
|
// Throw error if no method on the class of the specified name.
|
||||||
torch::jit::Function *function = &classType->getMethod(path.back());
|
torch::jit::Function* function = &classType->getMethod(path.back());
|
||||||
|
|
||||||
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
|
ClassAnnotation& classAnnotation = getOrCreateClassAnnotation(classType);
|
||||||
std::vector<MethodAnnotation> &methodAnnotations =
|
std::vector<MethodAnnotation>& methodAnnotations =
|
||||||
classAnnotation.getMethodAnnotations();
|
classAnnotation.getMethodAnnotations();
|
||||||
const std::vector<torch::jit::Function *> &methods = classType->methods();
|
const std::vector<torch::jit::Function*>& methods = classType->methods();
|
||||||
for (int i = 0, e = methods.size(); i != e; i++) {
|
for (int i = 0, e = methods.size(); i != e; i++) {
|
||||||
if (methods[i]->name() == path.back()) {
|
if (methods[i]->name() == path.back()) {
|
||||||
fillArgAnnotations(methodAnnotations[i], argAnnotations, function);
|
fillArgAnnotations(methodAnnotations[i], argAnnotations, function);
|
||||||
|
@ -191,9 +193,9 @@ void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType,
|
c10::ClassType* ClassAnnotator::getClassAtPath(
|
||||||
std::vector<std::string> path) {
|
c10::ClassType* rootClassType, std::vector<std::string> path) {
|
||||||
c10::ClassType *classType = rootClassType;
|
c10::ClassType* classType = rootClassType;
|
||||||
// Reverse so that pop_back gives us the initial atoms first.
|
// Reverse so that pop_back gives us the initial atoms first.
|
||||||
std::reverse(path.begin(), path.end());
|
std::reverse(path.begin(), path.end());
|
||||||
while (!path.empty()) {
|
while (!path.empty()) {
|
||||||
|
@ -215,8 +217,8 @@ c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Helper methods
|
// Helper methods
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
MethodAnnotation *
|
MethodAnnotation*
|
||||||
ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) {
|
ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function* function) {
|
||||||
auto it = functionToMethodMap.find(function);
|
auto it = functionToMethodMap.find(function);
|
||||||
if (it == functionToMethodMap.end()) {
|
if (it == functionToMethodMap.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -228,7 +230,7 @@ ClassAnnotator::getMethodAnnotationForFunction(torch::jit::Function *function) {
|
||||||
// toString methods
|
// toString methods
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
std::string AttributeAnnotation::toString(const std::string &name) {
|
std::string AttributeAnnotation::toString(const std::string& name) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "AttributeAnnotation('" << name << "') {\n";
|
ss << "AttributeAnnotation('" << name << "') {\n";
|
||||||
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
|
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
|
||||||
|
@ -259,7 +261,7 @@ std::string ArgAnnotation::toString(int argIndex) {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string MethodAnnotation::toString(const std::string &name) {
|
std::string MethodAnnotation::toString(const std::string& name) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "MethodAnnotation('" << name << "') {\n";
|
ss << "MethodAnnotation('" << name << "') {\n";
|
||||||
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
|
ss << " isExported = " << (isExported ? "true" : "false") << "\n";
|
||||||
|
@ -280,13 +282,13 @@ std::string ClassAnnotation::toString() {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n";
|
ss << "ClassAnnotation('" << classType->name()->qualifiedName() << "') {\n";
|
||||||
|
|
||||||
const std::vector<c10::ClassAttribute> &classAttributes =
|
const std::vector<c10::ClassAttribute>& classAttributes =
|
||||||
classType->getAttributes();
|
classType->getAttributes();
|
||||||
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
||||||
ss << indentString(
|
ss << indentString(
|
||||||
" ", attributeAnnotations[i].toString(classAttributes[i].getName()));
|
" ", attributeAnnotations[i].toString(classAttributes[i].getName()));
|
||||||
}
|
}
|
||||||
const std::vector<torch::jit::Function *> &methods = classType->methods();
|
const std::vector<torch::jit::Function*>& methods = classType->methods();
|
||||||
for (int i = 0, e = methods.size(); i != e; i++) {
|
for (int i = 0, e = methods.size(); i != e; i++) {
|
||||||
ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name()));
|
ss << indentString(" ", methodAnnotations[i].toString(methods[i]->name()));
|
||||||
}
|
}
|
||||||
|
@ -297,7 +299,7 @@ std::string ClassAnnotation::toString() {
|
||||||
std::string ClassAnnotator::toString() {
|
std::string ClassAnnotator::toString() {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "ClassAnnotator {\n";
|
ss << "ClassAnnotator {\n";
|
||||||
for (auto &p : classAnnotations) {
|
for (auto& p : classAnnotations) {
|
||||||
ss << indentString(" ", p.second->toString());
|
ss << indentString(" ", p.second->toString());
|
||||||
}
|
}
|
||||||
ss << "}\n";
|
ss << "}\n";
|
|
@ -34,7 +34,7 @@ struct AttributeAnnotation {
|
||||||
// can be externally accessed.
|
// can be externally accessed.
|
||||||
bool isExported = true;
|
bool isExported = true;
|
||||||
|
|
||||||
std::string toString(const std::string &name);
|
std::string toString(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
// An annotation of an argument of a method.
|
// An annotation of an argument of a method.
|
||||||
|
@ -80,7 +80,7 @@ struct MethodAnnotation {
|
||||||
// large printout of the default ArgAnnotation for every method.
|
// large printout of the default ArgAnnotation for every method.
|
||||||
c10::optional<std::vector<ArgAnnotation>> argAnnotations;
|
c10::optional<std::vector<ArgAnnotation>> argAnnotations;
|
||||||
|
|
||||||
std::string toString(const std::string &name);
|
std::string toString(const std::string& name);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Annotations on a c10::ClassType.
|
// Annotations on a c10::ClassType.
|
||||||
|
@ -107,10 +107,10 @@ public:
|
||||||
|
|
||||||
// Get the attribute annotations.
|
// Get the attribute annotations.
|
||||||
// The length and order is the same as `classType->getAttributes()`.
|
// The length and order is the same as `classType->getAttributes()`.
|
||||||
std::vector<AttributeAnnotation> &getAttributeAnnotations();
|
std::vector<AttributeAnnotation>& getAttributeAnnotations();
|
||||||
// Get the method annotations.
|
// Get the method annotations.
|
||||||
// The length and order is the same as `classType->methods()`.
|
// The length and order is the same as `classType->methods()`.
|
||||||
std::vector<MethodAnnotation> &getMethodAnnotations();
|
std::vector<MethodAnnotation>& getMethodAnnotations();
|
||||||
|
|
||||||
std::string toString();
|
std::string toString();
|
||||||
|
|
||||||
|
@ -141,14 +141,14 @@ public:
|
||||||
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
|
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
|
||||||
// have a submodule `a` and that submodule should have a method or attribute
|
// have a submodule `a` and that submodule should have a method or attribute
|
||||||
// `b`.
|
// `b`.
|
||||||
void exportPath(c10::ClassType &rootClassType,
|
void exportPath(
|
||||||
std::vector<std::string> exportedPath);
|
c10::ClassType& rootClassType, std::vector<std::string> exportedPath);
|
||||||
// Mark everything as not-exported.
|
// Mark everything as not-exported.
|
||||||
//
|
//
|
||||||
// This is kind of useless by itself, but together with `exportPath` allows
|
// This is kind of useless by itself, but together with `exportPath` allows
|
||||||
// exporting a subset of known names out of a larger collection of unknown
|
// exporting a subset of known names out of a larger collection of unknown
|
||||||
// names.
|
// names.
|
||||||
void exportNone(c10::ClassType &rootClassType);
|
void exportNone(c10::ClassType& rootClassType);
|
||||||
|
|
||||||
// Annotate shapes and dtypes of the arguments of a method at path `path` from
|
// Annotate shapes and dtypes of the arguments of a method at path `path` from
|
||||||
// `rootClassType`.
|
// `rootClassType`.
|
||||||
|
@ -159,23 +159,23 @@ public:
|
||||||
// a "has value semantics" boolean.
|
// a "has value semantics" boolean.
|
||||||
// These will be put into an `ArgAnnotation` struct -- see there for
|
// These will be put into an `ArgAnnotation` struct -- see there for
|
||||||
// precise definitions of the promised semantics of each entry.
|
// precise definitions of the promised semantics of each entry.
|
||||||
void annotateArgs(c10::ClassType &rootClassType,
|
void annotateArgs(
|
||||||
std::vector<std::string> path,
|
c10::ClassType& rootClassType, std::vector<std::string> path,
|
||||||
std::vector<ArgAnnotation> argAnnotations);
|
std::vector<ArgAnnotation> argAnnotations);
|
||||||
|
|
||||||
// The annotations collected so far.
|
// The annotations collected so far.
|
||||||
const ClassAnnotationMap &getAnnotationMap();
|
const ClassAnnotationMap& getAnnotationMap();
|
||||||
|
|
||||||
// Get the ClassAnnotation corresponding to `classType`.
|
// Get the ClassAnnotation corresponding to `classType`.
|
||||||
ClassAnnotation &getOrCreateClassAnnotation(c10::ClassType *classType);
|
ClassAnnotation& getOrCreateClassAnnotation(c10::ClassType* classType);
|
||||||
|
|
||||||
// Helper to find the MethodAnnotation corresponding to a
|
// Helper to find the MethodAnnotation corresponding to a
|
||||||
// torch::jit::Function, or null if not found.
|
// torch::jit::Function, or null if not found.
|
||||||
//
|
//
|
||||||
// Users could in principle scan all annotations to find this, but it's more
|
// Users could in principle scan all annotations to find this, but it's more
|
||||||
// efficient to maintain the reverse mapping directly.
|
// efficient to maintain the reverse mapping directly.
|
||||||
MethodAnnotation *
|
MethodAnnotation*
|
||||||
getMethodAnnotationForFunction(torch::jit::Function *function);
|
getMethodAnnotationForFunction(torch::jit::Function* function);
|
||||||
|
|
||||||
std::string toString();
|
std::string toString();
|
||||||
|
|
||||||
|
@ -183,11 +183,11 @@ private:
|
||||||
// Traverse `path` starting from `rootClassType` to find the ClassType
|
// Traverse `path` starting from `rootClassType` to find the ClassType
|
||||||
// of a presumed nested submodule. Throw an error if there is no such
|
// of a presumed nested submodule. Throw an error if there is no such
|
||||||
// submodule.
|
// submodule.
|
||||||
c10::ClassType *getClassAtPath(c10::ClassType *rootClassType,
|
c10::ClassType*
|
||||||
std::vector<std::string> path);
|
getClassAtPath(c10::ClassType* rootClassType, std::vector<std::string> path);
|
||||||
ClassAnnotationMap classAnnotations;
|
ClassAnnotationMap classAnnotations;
|
||||||
// Reverse mapping used to service getMethodAnnotationForFunction.
|
// Reverse mapping used to service getMethodAnnotationForFunction.
|
||||||
std::unordered_map<torch::jit::Function *, MethodAnnotation *>
|
std::unordered_map<torch::jit::Function*, MethodAnnotation*>
|
||||||
functionToMethodMap;
|
functionToMethodMap;
|
||||||
};
|
};
|
||||||
|
|
|
@ -18,7 +18,7 @@ using namespace torch_mlir;
|
||||||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||||
if (THPDtype_Check(obj.ptr())) {
|
if (THPDtype_Check(obj.ptr())) {
|
||||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
THPDtype* dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||||
return dtype->scalar_type;
|
return dtype->scalar_type;
|
||||||
}
|
}
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
@ -48,16 +48,17 @@ static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
|
||||||
return argAnnotations;
|
return argAnnotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
void torch_mlir::initClassAnnotatorBindings(py::module& m) {
|
||||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def("exportPath", &ClassAnnotator::exportPath)
|
.def("exportPath", &ClassAnnotator::exportPath)
|
||||||
.def("exportNone", &ClassAnnotator::exportNone)
|
.def("exportNone", &ClassAnnotator::exportNone)
|
||||||
.def("annotateArgs",
|
.def(
|
||||||
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
|
"annotateArgs",
|
||||||
std::vector<std::string> path, py::list argAnnotations) {
|
[&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType,
|
||||||
cls_annotator.annotateArgs(rootClassType, path,
|
std::vector<std::string> path, py::list argAnnotations) {
|
||||||
getArgAnnotations(argAnnotations));
|
cls_annotator.annotateArgs(
|
||||||
})
|
rootClassType, path, getArgAnnotations(argAnnotations));
|
||||||
|
})
|
||||||
.def("__repr__", &ClassAnnotator::toString);
|
.def("__repr__", &ClassAnnotator::toString);
|
||||||
}
|
}
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
void initClassAnnotatorBindings(py::module &m);
|
void initClassAnnotatorBindings(py::module& m);
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
|
@ -21,9 +21,9 @@
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
||||||
MlirContext context, torch::jit::Function *function,
|
MlirContext context, torch::jit::Function* function,
|
||||||
std::function<MlirAttribute(int)> getArgAttribute,
|
std::function<MlirAttribute(int)> getArgAttribute,
|
||||||
const ImportOptions &importOptions) {
|
const ImportOptions& importOptions) {
|
||||||
// Useful for debugging:
|
// Useful for debugging:
|
||||||
// graph->dump();
|
// graph->dump();
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
@ -63,10 +63,11 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
||||||
}
|
}
|
||||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||||
MlirBlock appendToBlock) {
|
MlirBlock appendToBlock) {
|
||||||
createMlirOperationAtEnd(appendToBlock, "func.return", loc,
|
createMlirOperationAtEnd(
|
||||||
adjustStaticInformationForValues(
|
appendToBlock, "func.return", loc,
|
||||||
appendToBlock, loc, yieldedValues, resultTypes,
|
adjustStaticInformationForValues(
|
||||||
/*userAllowsRefinement=*/false));
|
appendToBlock, loc, yieldedValues, resultTypes,
|
||||||
|
/*userAllowsRefinement=*/false));
|
||||||
};
|
};
|
||||||
MlirBlock block = importBlock(
|
MlirBlock block = importBlock(
|
||||||
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
|
@ -40,10 +40,10 @@ namespace torch_mlir {
|
||||||
/// null MlirAttribute is returned, no attribute will be attached to that
|
/// null MlirAttribute is returned, no attribute will be attached to that
|
||||||
/// argument.
|
/// argument.
|
||||||
MlirOperation importJitFunctionAsFuncOp(
|
MlirOperation importJitFunctionAsFuncOp(
|
||||||
MlirContext context, torch::jit::Function *function,
|
MlirContext context, torch::jit::Function* function,
|
||||||
std::function<MlirAttribute(int)> getArgAttribute =
|
std::function<MlirAttribute(int)> getArgAttribute =
|
||||||
[](int) -> MlirAttribute { return {nullptr}; },
|
[](int) -> MlirAttribute { return {nullptr}; },
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions& importOptions = {});
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -50,9 +50,9 @@ static py::list getRegisteredOps() {
|
||||||
// since the JIT has its own dispatch mechanism that it uses to implement
|
// since the JIT has its own dispatch mechanism that it uses to implement
|
||||||
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||||
// as `aten::__is__`.
|
// as `aten::__is__`.
|
||||||
for (const std::shared_ptr<torch::jit::Operator> &op :
|
for (const std::shared_ptr<torch::jit::Operator>& op :
|
||||||
torch::jit::getAllOperators()) {
|
torch::jit::getAllOperators()) {
|
||||||
const c10::FunctionSchema &schema = op->schema();
|
const c10::FunctionSchema& schema = op->schema();
|
||||||
|
|
||||||
py::dict record;
|
py::dict record;
|
||||||
{
|
{
|
||||||
|
@ -69,7 +69,7 @@ static py::list getRegisteredOps() {
|
||||||
|
|
||||||
py::list arguments;
|
py::list arguments;
|
||||||
py::list returns;
|
py::list returns;
|
||||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
auto addArgument = [](py::list& container, const c10::Argument& arg) {
|
||||||
py::dict argRecord;
|
py::dict argRecord;
|
||||||
argRecord["name"] = arg.name();
|
argRecord["name"] = arg.name();
|
||||||
argRecord["type"] = arg.type()->str();
|
argRecord["type"] = arg.type()->str();
|
||||||
|
@ -87,10 +87,10 @@ static py::list getRegisteredOps() {
|
||||||
py::dict aliasInfo;
|
py::dict aliasInfo;
|
||||||
py::list before;
|
py::list before;
|
||||||
py::list after;
|
py::list after;
|
||||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
for (auto& symbol : arg.alias_info()->beforeSets()) {
|
||||||
before.append(std::string(symbol.toQualString()));
|
before.append(std::string(symbol.toQualString()));
|
||||||
}
|
}
|
||||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
for (auto& symbol : arg.alias_info()->afterSets()) {
|
||||||
after.append(std::string(symbol.toQualString()));
|
after.append(std::string(symbol.toQualString()));
|
||||||
}
|
}
|
||||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||||
|
@ -101,10 +101,10 @@ static py::list getRegisteredOps() {
|
||||||
|
|
||||||
container.append(std::move(argRecord));
|
container.append(std::move(argRecord));
|
||||||
};
|
};
|
||||||
for (auto &argument : schema.arguments()) {
|
for (auto& argument : schema.arguments()) {
|
||||||
addArgument(arguments, argument);
|
addArgument(arguments, argument);
|
||||||
}
|
}
|
||||||
for (auto &returnArg : schema.returns()) {
|
for (auto& returnArg : schema.returns()) {
|
||||||
addArgument(returns, returnArg);
|
addArgument(returns, returnArg);
|
||||||
}
|
}
|
||||||
record["arguments"] = std::move(arguments);
|
record["arguments"] = std::move(arguments);
|
||||||
|
@ -115,6 +115,6 @@ static py::list getRegisteredOps() {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) {
|
void torch_mlir::initGetRegisteredOpsBindings(py::module& m) {
|
||||||
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
||||||
}
|
}
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
void initGetRegisteredOpsBindings(py::module &m);
|
void initGetRegisteredOpsBindings(py::module& m);
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -14,11 +14,13 @@ namespace py = pybind11;
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
void torch_mlir::initImportOptionsBindings(py::module &m) {
|
void torch_mlir::initImportOptionsBindings(py::module& m) {
|
||||||
py::class_<ImportOptions>(m, "ImportOptions")
|
py::class_<ImportOptions>(m, "ImportOptions")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def_readwrite("assumeTensorsHaveValueSemantics",
|
.def_readwrite(
|
||||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
"assumeTensorsHaveValueSemantics",
|
||||||
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
.def_readwrite(
|
||||||
|
"ignoreExistingTensorShapesAndDtypes",
|
||||||
|
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
||||||
}
|
}
|
|
@ -13,7 +13,7 @@
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
void initImportOptionsBindings(pybind11::module &m);
|
void initImportOptionsBindings(pybind11::module& m);
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H
|
|
@ -49,10 +49,10 @@ using namespace torch_mlir;
|
||||||
// throw an error on).
|
// throw an error on).
|
||||||
namespace {
|
namespace {
|
||||||
struct IValueHasher {
|
struct IValueHasher {
|
||||||
size_t operator()(const c10::IValue &ivalue) const {
|
size_t operator()(const c10::IValue& ivalue) const {
|
||||||
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
|
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
|
||||||
return std::hash<const void *>()(
|
return std::hash<const void*>()(
|
||||||
static_cast<const void *>(ivalue.internalToPointer()));
|
static_cast<const void*>(ivalue.internalToPointer()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return c10::IValue::hash(ivalue);
|
return c10::IValue::hash(ivalue);
|
||||||
|
@ -65,7 +65,7 @@ struct IValueHasher {
|
||||||
// such as when tracing). Can we do better?
|
// such as when tracing). Can we do better?
|
||||||
namespace {
|
namespace {
|
||||||
struct IValueEq {
|
struct IValueEq {
|
||||||
bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
|
bool operator()(const c10::IValue& lhs, const c10::IValue& rhs) const {
|
||||||
return lhs.isSameIdentity(rhs);
|
return lhs.isSameIdentity(rhs);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -99,8 +99,9 @@ namespace {
|
||||||
/// (PyTorch allows this!).
|
/// (PyTorch allows this!).
|
||||||
class IValueImporter {
|
class IValueImporter {
|
||||||
public:
|
public:
|
||||||
IValueImporter(MlirBlock importBlock, MlirContext context,
|
IValueImporter(
|
||||||
ClassAnnotator &annotator, const ImportOptions &importOptions)
|
MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator,
|
||||||
|
const ImportOptions& importOptions)
|
||||||
: importBlock(importBlock), context(context), annotator(annotator),
|
: importBlock(importBlock), context(context), annotator(annotator),
|
||||||
importOptions(importOptions) {}
|
importOptions(importOptions) {}
|
||||||
|
|
||||||
|
@ -110,15 +111,16 @@ private:
|
||||||
MlirValue rawImportIValue(c10::IValue ivalue);
|
MlirValue rawImportIValue(c10::IValue ivalue);
|
||||||
MlirValue importTensor(c10::IValue ivalue);
|
MlirValue importTensor(c10::IValue ivalue);
|
||||||
MlirValue importModule(torch::jit::Module jitModule);
|
MlirValue importModule(torch::jit::Module jitModule);
|
||||||
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
|
void importMethod(
|
||||||
const MethodAnnotation &methodAnnotation);
|
torch::jit::Function* function, MlirBlock classTypeBody,
|
||||||
void importClassType(c10::ClassType *classType);
|
const MethodAnnotation& methodAnnotation);
|
||||||
void importCompilationUnit(torch::jit::CompilationUnit *cu);
|
void importClassType(c10::ClassType* classType);
|
||||||
|
void importCompilationUnit(torch::jit::CompilationUnit* cu);
|
||||||
|
|
||||||
MlirBlock importBlock;
|
MlirBlock importBlock;
|
||||||
MlirContext context;
|
MlirContext context;
|
||||||
ClassAnnotator &annotator;
|
ClassAnnotator& annotator;
|
||||||
const ImportOptions &importOptions;
|
const ImportOptions& importOptions;
|
||||||
|
|
||||||
// Map tracking already-imported values.
|
// Map tracking already-imported values.
|
||||||
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
||||||
|
@ -129,16 +131,16 @@ private:
|
||||||
// e.g. methods (the function names are meaningful and match with Python's
|
// e.g. methods (the function names are meaningful and match with Python's
|
||||||
// module hierarchy, with the exception of `__main__` being replaced with
|
// module hierarchy, with the exception of `__main__` being replaced with
|
||||||
// `__torch__`).
|
// `__torch__`).
|
||||||
torch::jit::CompilationUnit *compilationUnit = nullptr;
|
torch::jit::CompilationUnit* compilationUnit = nullptr;
|
||||||
|
|
||||||
// Used to detect potentially aliasing tensors.
|
// Used to detect potentially aliasing tensors.
|
||||||
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
|
std::unordered_set<c10::StorageImpl*> seenStorageImpls;
|
||||||
// The set of ClassType's that have already been imported.
|
// The set of ClassType's that have already been imported.
|
||||||
//
|
//
|
||||||
// ClassType's are referenced via their `classType->name()->qualifiedName()`
|
// ClassType's are referenced via their `classType->name()->qualifiedName()`
|
||||||
// string (as an MLIR symbol name) so we don't need to keep a map associating
|
// string (as an MLIR symbol name) so we don't need to keep a map associating
|
||||||
// them with the MlirOperation that they import into.
|
// them with the MlirOperation that they import into.
|
||||||
std::unordered_set<c10::ClassType *> classTypes;
|
std::unordered_set<c10::ClassType*> classTypes;
|
||||||
// The stack of attribute names we have traversed to reach the current IValue.
|
// The stack of attribute names we have traversed to reach the current IValue.
|
||||||
// Used for diagnostics.
|
// Used for diagnostics.
|
||||||
std::vector<std::string> attributeNameStack;
|
std::vector<std::string> attributeNameStack;
|
||||||
|
@ -190,7 +192,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||||
mlirRegionCreate());
|
mlirRegionCreate());
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
|
mlirRegionAppendOwnedBlock(
|
||||||
|
nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
|
||||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
InserterGuard inserterGuard(importBlock, nnModule);
|
InserterGuard inserterGuard(importBlock, nnModule);
|
||||||
|
|
||||||
|
@ -198,13 +201,14 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
rootModuleName = moduleTypeName;
|
rootModuleName = moduleTypeName;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
|
const std::vector<c10::IValue>& slots = currentModule._ivalue()->slots();
|
||||||
const std::vector<c10::ClassAttribute> &classAttributes =
|
const std::vector<c10::ClassAttribute>& classAttributes =
|
||||||
currentModule.type()->getAttributes();
|
currentModule.type()->getAttributes();
|
||||||
assert(slots.size() == classAttributes.size() &&
|
assert(
|
||||||
"mismatch between object and type!");
|
slots.size() == classAttributes.size() &&
|
||||||
|
"mismatch between object and type!");
|
||||||
for (int i = 0, e = slots.size(); i < e; i++) {
|
for (int i = 0, e = slots.size(); i < e; i++) {
|
||||||
const c10::ClassAttribute &classAttribute = classAttributes[i];
|
const c10::ClassAttribute& classAttribute = classAttributes[i];
|
||||||
attributeNameStack.push_back(classAttribute.getName());
|
attributeNameStack.push_back(classAttribute.getName());
|
||||||
MlirValue slotValue = importIValue(slots[i]);
|
MlirValue slotValue = importIValue(slots[i]);
|
||||||
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
||||||
|
@ -231,7 +235,7 @@ MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
// Reject potentially aliased tensors.
|
// Reject potentially aliased tensors.
|
||||||
if (ivalue.isTensor()) {
|
if (ivalue.isTensor()) {
|
||||||
c10::StorageImpl *storageImpl =
|
c10::StorageImpl* storageImpl =
|
||||||
ivalue.toTensor().storage().unsafeGetStorageImpl();
|
ivalue.toTensor().storage().unsafeGetStorageImpl();
|
||||||
if (!seenStorageImpls.insert(storageImpl).second) {
|
if (!seenStorageImpls.insert(storageImpl).second) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
|
@ -257,8 +261,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
MlirType type = torchMlirTorchBoolTypeGet(context);
|
MlirType type = torchMlirTorchBoolTypeGet(context);
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.bool", loc, type,
|
importBlock, "torch.constant.bool", loc, type,
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute(
|
||||||
mlirBoolAttrGet(context, ivalue.toBool())));
|
"value", mlirBoolAttrGet(context, ivalue.toBool())));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isDouble()) {
|
if (ivalue.isDouble()) {
|
||||||
|
@ -266,23 +270,23 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.float", loc, type,
|
importBlock, "torch.constant.float", loc, type,
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
|
"value", mlirFloatAttrDoubleGet(
|
||||||
ivalue.toDouble())));
|
context, mlirF64TypeGet(context), ivalue.toDouble())));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isInt()) {
|
if (ivalue.isInt()) {
|
||||||
MlirType type = torchMlirTorchIntTypeGet(context);
|
MlirType type = torchMlirTorchIntTypeGet(context);
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.int", loc, type,
|
importBlock, "torch.constant.int", loc, type,
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute(
|
||||||
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64),
|
"value", mlirIntegerAttrGet(
|
||||||
ivalue.toInt())));
|
mlirIntegerTypeGet(context, 64), ivalue.toInt())));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isList()) {
|
if (ivalue.isList()) {
|
||||||
c10::List<c10::IValue> list = ivalue.toList();
|
c10::List<c10::IValue> list = ivalue.toList();
|
||||||
std::vector<MlirValue> elems;
|
std::vector<MlirValue> elems;
|
||||||
for (const c10::IValue &elem : list) {
|
for (const c10::IValue& elem : list) {
|
||||||
elems.push_back(importIValue(elem));
|
elems.push_back(importIValue(elem));
|
||||||
}
|
}
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
@ -312,7 +316,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
auto list = ivalue.toTuple()->elements();
|
auto list = ivalue.toTuple()->elements();
|
||||||
std::vector<MlirValue> operands;
|
std::vector<MlirValue> operands;
|
||||||
std::vector<MlirType> types;
|
std::vector<MlirType> types;
|
||||||
for (const c10::IValue &elem : list) {
|
for (const c10::IValue& elem : list) {
|
||||||
MlirValue operand = importIValue(elem);
|
MlirValue operand = importIValue(elem);
|
||||||
operands.push_back(operand);
|
operands.push_back(operand);
|
||||||
types.push_back(mlirValueGetType(operand));
|
types.push_back(mlirValueGetType(operand));
|
||||||
|
@ -335,14 +339,14 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
torchMlirTorchStringTypeGet(context),
|
torchMlirTorchStringTypeGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value",
|
"value",
|
||||||
mlirStringAttrGet(context,
|
mlirStringAttrGet(
|
||||||
toMlirStringRef(ivalue.toString()->string()))));
|
context, toMlirStringRef(ivalue.toString()->string()))));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isNone()) {
|
if (ivalue.isNone()) {
|
||||||
MlirOperation operation =
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
|
importBlock, "torch.constant.none", loc,
|
||||||
torchMlirTorchNoneTypeGet(context));
|
torchMlirTorchNoneTypeGet(context));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isCustomClass()) {
|
if (ivalue.isCustomClass()) {
|
||||||
|
@ -436,12 +440,12 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
return tensorValue;
|
return tensorValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
void IValueImporter::importMethod(torch::jit::Function *function,
|
void IValueImporter::importMethod(
|
||||||
MlirBlock classTypeBody,
|
torch::jit::Function* function, MlirBlock classTypeBody,
|
||||||
const MethodAnnotation &methodAnnotation) {
|
const MethodAnnotation& methodAnnotation) {
|
||||||
// The function's name becomes the MLIR symbol table name of the imported func
|
// The function's name becomes the MLIR symbol table name of the imported func
|
||||||
// when we import the compilation unit.
|
// when we import the compilation unit.
|
||||||
const std::string &symName = function->qualname().qualifiedName();
|
const std::string& symName = function->qualname().qualifiedName();
|
||||||
MlirAttribute functionSymbolRef =
|
MlirAttribute functionSymbolRef =
|
||||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
|
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
|
||||||
|
|
||||||
|
@ -457,7 +461,7 @@ void IValueImporter::importMethod(torch::jit::Function *function,
|
||||||
toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
|
toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IValueImporter::importClassType(c10::ClassType *classType) {
|
void IValueImporter::importClassType(c10::ClassType* classType) {
|
||||||
if (!classTypes.insert(classType).second) {
|
if (!classTypes.insert(classType).second) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -475,13 +479,13 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
|
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
|
||||||
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
||||||
|
|
||||||
ClassAnnotation &classAnnotation =
|
ClassAnnotation& classAnnotation =
|
||||||
annotator.getOrCreateClassAnnotation(classType);
|
annotator.getOrCreateClassAnnotation(classType);
|
||||||
|
|
||||||
const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
|
const auto& attributeAnnotations = classAnnotation.getAttributeAnnotations();
|
||||||
const auto &classAttributes = classType->getAttributes();
|
const auto& classAttributes = classType->getAttributes();
|
||||||
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
for (int i = 0, e = classAttributes.size(); i != e; i++) {
|
||||||
const c10::ClassAttribute &classAttribute = classAttributes[i];
|
const c10::ClassAttribute& classAttribute = classAttributes[i];
|
||||||
c10::optional<MlirNamedAttribute> isPrivate;
|
c10::optional<MlirNamedAttribute> isPrivate;
|
||||||
if (!attributeAnnotations[i].isExported) {
|
if (!attributeAnnotations[i].isExported) {
|
||||||
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
|
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
|
||||||
|
@ -491,13 +495,14 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"name", mlirStringAttrGet(
|
"name", mlirStringAttrGet(
|
||||||
context, toMlirStringRef(classAttribute.getName()))),
|
context, toMlirStringRef(classAttribute.getName()))),
|
||||||
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
toMlirNamedAttribute(
|
||||||
loc, classAttribute.getType(), importOptions))),
|
"type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||||
|
loc, classAttribute.getType(), importOptions))),
|
||||||
isPrivate);
|
isPrivate);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
|
const auto& methodAnnotations = classAnnotation.getMethodAnnotations();
|
||||||
const auto &methods = classType->methods();
|
const auto& methods = classType->methods();
|
||||||
for (int i = 0, e = methods.size(); i != e; i++) {
|
for (int i = 0, e = methods.size(); i != e; i++) {
|
||||||
importMethod(methods[i], classTypeBody, methodAnnotations[i]);
|
importMethod(methods[i], classTypeBody, methodAnnotations[i]);
|
||||||
}
|
}
|
||||||
|
@ -505,7 +510,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
|
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
|
||||||
if (compilationUnit == nullptr) {
|
if (compilationUnit == nullptr) {
|
||||||
compilationUnit = cu;
|
compilationUnit = cu;
|
||||||
} else {
|
} else {
|
||||||
|
@ -524,14 +529,14 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (torch::jit::Function *function : cu->get_functions()) {
|
for (torch::jit::Function* function : cu->get_functions()) {
|
||||||
// Useful for debugging errors in free functions that end up being
|
// Useful for debugging errors in free functions that end up being
|
||||||
// unused. These can be missing when round-tripping through the on-disk
|
// unused. These can be missing when round-tripping through the on-disk
|
||||||
// format, even though they still cause import issues when importing
|
// format, even though they still cause import issues when importing
|
||||||
// through the larger Python session where they originate.
|
// through the larger Python session where they originate.
|
||||||
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
|
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
|
||||||
// std::cerr << *torch::jit::toGraphFunction(function).graph();
|
// std::cerr << *torch::jit::toGraphFunction(function).graph();
|
||||||
MethodAnnotation *annotation =
|
MethodAnnotation* annotation =
|
||||||
annotator.getMethodAnnotationForFunction(function);
|
annotator.getMethodAnnotationForFunction(function);
|
||||||
MlirOperation func = importJitFunctionAsFuncOp(
|
MlirOperation func = importJitFunctionAsFuncOp(
|
||||||
context, function,
|
context, function,
|
||||||
|
@ -539,9 +544,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
if (!annotation || !annotation->argAnnotations.has_value()) {
|
if (!annotation || !annotation->argAnnotations.has_value()) {
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
}
|
}
|
||||||
c10::optional<std::vector<int64_t>> &maybeShape =
|
c10::optional<std::vector<int64_t>>& maybeShape =
|
||||||
annotation->argAnnotations.value()[argIndex].shape;
|
annotation->argAnnotations.value()[argIndex].shape;
|
||||||
c10::optional<c10::ScalarType> &maybeDtype =
|
c10::optional<c10::ScalarType>& maybeDtype =
|
||||||
annotation->argAnnotations.value()[argIndex].dtype;
|
annotation->argAnnotations.value()[argIndex].dtype;
|
||||||
bool hasValueSemantics =
|
bool hasValueSemantics =
|
||||||
annotation->argAnnotations.value()[argIndex].hasValueSemantics;
|
annotation->argAnnotations.value()[argIndex].hasValueSemantics;
|
||||||
|
@ -561,10 +566,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
// the C API constructor, when we want the "we know we have 0 sizes"
|
// the C API constructor, when we want the "we know we have 0 sizes"
|
||||||
// case. So use a dummy data pointer.
|
// case. So use a dummy data pointer.
|
||||||
int64_t dummy;
|
int64_t dummy;
|
||||||
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
|
int64_t* shapeData = shape.size() == 0 ? &dummy : shape.data();
|
||||||
if (hasValueSemantics) {
|
if (hasValueSemantics) {
|
||||||
typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
|
typeBound = torchMlirTorchValueTensorTypeGet(
|
||||||
shapeData, dtype);
|
context, shape.size(), shapeData, dtype);
|
||||||
} else {
|
} else {
|
||||||
typeBound = torchMlirTorchNonValueTensorTypeGet(
|
typeBound = torchMlirTorchNonValueTensorTypeGet(
|
||||||
context, shape.size(), shapeData, dtype);
|
context, shape.size(), shapeData, dtype);
|
||||||
|
@ -592,10 +597,9 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
MlirValue torch_mlir::importIValue(
|
||||||
MlirContext context,
|
c10::IValue ivalue, MlirBlock block, MlirContext context,
|
||||||
ClassAnnotator &annotator,
|
ClassAnnotator& annotator, const ImportOptions& importOptions) {
|
||||||
const ImportOptions &importOptions) {
|
|
||||||
// When debugging module importing, it can be useful to dump as so:
|
// When debugging module importing, it can be useful to dump as so:
|
||||||
// if (ivalue.isModule())
|
// if (ivalue.isModule())
|
||||||
// ivalue.toModule().dump(true, false, false);
|
// ivalue.toModule().dump(true, false, false);
|
|
@ -25,9 +25,9 @@ namespace torch_mlir {
|
||||||
|
|
||||||
/// Main entry-point for importing torch IValue's .
|
/// Main entry-point for importing torch IValue's .
|
||||||
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
|
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
|
||||||
MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
|
MlirValue importIValue(
|
||||||
ClassAnnotator &annotator,
|
c10::IValue ivalue, MlirBlock block, MlirContext context,
|
||||||
const ImportOptions &importOptions);
|
ClassAnnotator& annotator, const ImportOptions& importOptions);
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -22,92 +22,92 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
inline MlirStringRef toMlirStringRef(const std::string &s) {
|
inline MlirStringRef toMlirStringRef(const std::string& s) {
|
||||||
return mlirStringRefCreate(s.data(), s.size());
|
return mlirStringRefCreate(s.data(), s.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline MlirStringRef toMlirStringRef(const char *s) {
|
inline MlirStringRef toMlirStringRef(const char* s) {
|
||||||
return mlirStringRefCreate(s, std::strlen(s));
|
return mlirStringRefCreate(s, std::strlen(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
|
inline MlirNamedAttribute
|
||||||
MlirAttribute attr) {
|
toMlirNamedAttribute(const char* s, MlirAttribute attr) {
|
||||||
MlirContext context = mlirAttributeGetContext(attr);
|
MlirContext context = mlirAttributeGetContext(attr);
|
||||||
MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
|
MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
|
||||||
return mlirNamedAttributeGet(ident, attr);
|
return mlirNamedAttributeGet(ident, attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void addToMlirOperationState(
|
||||||
MlirNamedAttribute namedAttr) {
|
MlirOperationState& state, MlirNamedAttribute namedAttr) {
|
||||||
mlirOperationStateAddAttributes(&state, 1, &namedAttr);
|
mlirOperationStateAddAttributes(&state, 1, &namedAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void
|
||||||
MlirRegion region) {
|
addToMlirOperationState(MlirOperationState& state, MlirRegion region) {
|
||||||
mlirOperationStateAddOwnedRegions(&state, 1, ®ion);
|
mlirOperationStateAddOwnedRegions(&state, 1, ®ion);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void
|
||||||
MlirValue value) {
|
addToMlirOperationState(MlirOperationState& state, MlirValue value) {
|
||||||
mlirOperationStateAddOperands(&state, 1, &value);
|
mlirOperationStateAddOperands(&state, 1, &value);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void addToMlirOperationState(
|
||||||
const std::vector<MlirValue> &values) {
|
MlirOperationState& state, const std::vector<MlirValue>& values) {
|
||||||
mlirOperationStateAddOperands(&state, values.size(), values.data());
|
mlirOperationStateAddOperands(&state, values.size(), values.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void addToMlirOperationState(
|
||||||
c10::ArrayRef<MlirValue> values) {
|
MlirOperationState& state, c10::ArrayRef<MlirValue> values) {
|
||||||
mlirOperationStateAddOperands(&state, values.size(), values.data());
|
mlirOperationStateAddOperands(&state, values.size(), values.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void
|
||||||
MlirType resultType) {
|
addToMlirOperationState(MlirOperationState& state, MlirType resultType) {
|
||||||
mlirOperationStateAddResults(&state, 1, &resultType);
|
mlirOperationStateAddResults(&state, 1, &resultType);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void addToMlirOperationState(
|
||||||
const std::vector<MlirType> &resultTypes) {
|
MlirOperationState& state, const std::vector<MlirType>& resultTypes) {
|
||||||
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
|
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state,
|
inline void addToMlirOperationState(
|
||||||
c10::ArrayRef<MlirType> resultTypes) {
|
MlirOperationState& state, c10::ArrayRef<MlirType> resultTypes) {
|
||||||
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
|
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void addToMlirOperationState(MlirOperationState &state, c10::optional<T> o) {
|
void addToMlirOperationState(MlirOperationState& state, c10::optional<T> o) {
|
||||||
if (o.has_value()) {
|
if (o.has_value()) {
|
||||||
addToMlirOperationState(state, o.value());
|
addToMlirOperationState(state, o.value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void addToMlirOperationState(MlirOperationState &state) {}
|
inline void addToMlirOperationState(MlirOperationState& state) {}
|
||||||
|
|
||||||
template <typename T, typename U, typename... Ts>
|
template <typename T, typename U, typename... Ts>
|
||||||
void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u,
|
void addToMlirOperationState(
|
||||||
Ts &&...ts) {
|
MlirOperationState& state, T&& t, U&& u, Ts&&... ts) {
|
||||||
addToMlirOperationState(state, std::forward<T>(t));
|
addToMlirOperationState(state, std::forward<T>(t));
|
||||||
addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
|
addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Ts>
|
template <typename... Ts>
|
||||||
MlirOperation createMlirOperation(std::string name, MlirLocation loc,
|
MlirOperation
|
||||||
Ts &&...ts) {
|
createMlirOperation(std::string name, MlirLocation loc, Ts&&... ts) {
|
||||||
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
|
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
|
||||||
addToMlirOperationState(state, std::forward<Ts>(ts)...);
|
addToMlirOperationState(state, std::forward<Ts>(ts)...);
|
||||||
return mlirOperationCreate(&state);
|
return mlirOperationCreate(&state);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Ts>
|
template <typename... Ts>
|
||||||
MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
|
MlirOperation createMlirOperationAtEnd(
|
||||||
MlirLocation loc, Ts &&...ts) {
|
MlirBlock block, std::string name, MlirLocation loc, Ts&&... ts) {
|
||||||
MlirOperation operation =
|
MlirOperation operation =
|
||||||
createMlirOperation(name, loc, std::forward<Ts>(ts)...);
|
createMlirOperation(name, loc, std::forward<Ts>(ts)...);
|
||||||
mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
operation);
|
block, mlirBlockGetTerminator(block), operation);
|
||||||
return operation;
|
return operation;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
static py::object getMlirIrClass(const char *className) {
|
static py::object getMlirIrClass(const char* className) {
|
||||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) {
|
||||||
return contextObj;
|
return contextObj;
|
||||||
}
|
}
|
||||||
|
|
||||||
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
|
static MlirContext castPythonObjectToMlirContext(py::object& contextObj) {
|
||||||
assert(!contextObj.is_none() && "context cannot be None");
|
assert(!contextObj.is_none() && "context cannot be None");
|
||||||
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||||
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
||||||
|
@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
||||||
<< ": ";
|
<< ": ";
|
||||||
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) {
|
auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) {
|
||||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
auto* ssp = static_cast<std::stringstream*>(stringCallbackUserData);
|
||||||
ssp->write(s.data, s.length);
|
ssp->write(s.data, s.length);
|
||||||
};
|
};
|
||||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void*>(&ss));
|
||||||
// Use pybind11's print:
|
// Use pybind11's print:
|
||||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||||
py::print(ss.str(),
|
py::print(
|
||||||
py::arg("file") = py::module_::import("sys").attr("stderr"));
|
ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||||
|
@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
||||||
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||||
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||||
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||||
void *) -> MlirLogicalResult {
|
void*) -> MlirLogicalResult {
|
||||||
printDiagnostic(diagnostic);
|
printDiagnostic(diagnostic);
|
||||||
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
||||||
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
||||||
|
@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||||
return mlirLogicalResultSuccess();
|
return mlirLogicalResultSuccess();
|
||||||
};
|
};
|
||||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||||
context, diagnosticHandler, nullptr, [](void *) { return; });
|
context, diagnosticHandler, nullptr, [](void*) { return; });
|
||||||
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
||||||
// of this context.
|
// of this context.
|
||||||
(void)id;
|
(void)id;
|
||||||
|
@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr ModuleBuilder::importFunction(
|
||||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
|
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions) {
|
||||||
py::object maybeImportOptions) {
|
|
||||||
ImportOptions importOptions;
|
ImportOptions importOptions;
|
||||||
if (!maybeImportOptions.is_none()) {
|
if (!maybeImportOptions.is_none()) {
|
||||||
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
||||||
}
|
}
|
||||||
MlirBlock block = getBodyBlock();
|
MlirBlock block = getBodyBlock();
|
||||||
MlirOperation terminator = this->terminator;
|
MlirOperation terminator = this->terminator;
|
||||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
|
MlirOperation func = importJitFunctionAsFuncOp(
|
||||||
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
context, function.function_,
|
||||||
|
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
void ModuleBuilder::importModule(
|
||||||
py::object maybeClassAnnotator,
|
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||||
py::object maybeImportOptions) {
|
py::object maybeImportOptions) {
|
||||||
ClassAnnotator dummyAnnotator;
|
ClassAnnotator dummyAnnotator;
|
||||||
ClassAnnotator *classAnnotator = &dummyAnnotator;
|
ClassAnnotator* classAnnotator = &dummyAnnotator;
|
||||||
if (!maybeClassAnnotator.is_none()) {
|
if (!maybeClassAnnotator.is_none()) {
|
||||||
classAnnotator = py::cast<ClassAnnotator *>(maybeClassAnnotator);
|
classAnnotator = py::cast<ClassAnnotator*>(maybeClassAnnotator);
|
||||||
}
|
}
|
||||||
ImportOptions importOptions;
|
ImportOptions importOptions;
|
||||||
if (!maybeImportOptions.is_none()) {
|
if (!maybeImportOptions.is_none()) {
|
||||||
|
@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
||||||
// precise `torch.class_type` names.
|
// precise `torch.class_type` names.
|
||||||
//
|
//
|
||||||
// This name is not semantically load-bearing!!!
|
// This name is not semantically load-bearing!!!
|
||||||
auto &name = *jitModule.type()->name();
|
auto& name = *jitModule.type()->name();
|
||||||
auto debugModuleNameAttr = mlirStringAttrGet(
|
auto debugModuleNameAttr = mlirStringAttrGet(
|
||||||
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
|
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
|
||||||
mlirOperationSetAttributeByName(mlirModuleGetOperation(module),
|
mlirOperationSetAttributeByName(
|
||||||
toMlirStringRef("torch.debug_module_name"),
|
mlirModuleGetOperation(module),
|
||||||
debugModuleNameAttr);
|
toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr);
|
||||||
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
|
importIValue(
|
||||||
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
jitModule._ivalue(), mlirModuleGetBody(module),
|
||||||
|
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock ModuleBuilder::getBodyBlock() {
|
MlirBlock ModuleBuilder::getBodyBlock() {
|
||||||
|
@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() {
|
||||||
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleBuilder::bind(py::module &m) {
|
void ModuleBuilder::bind(py::module& m) {
|
||||||
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
||||||
.def(py::init<py::object>(), py::arg("context") = py::none())
|
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||||
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
|
.def(
|
||||||
py::arg("importOptions") = py::none())
|
"import_function", &ModuleBuilder::importFunction,
|
||||||
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
py::arg("function"), py::arg("importOptions") = py::none())
|
||||||
py::arg("classAnnotator") = py::none(),
|
.def(
|
||||||
py::arg("importOptions") = py::none());
|
"import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||||
|
py::arg("classAnnotator") = py::none(),
|
||||||
|
py::arg("importOptions") = py::none());
|
||||||
}
|
}
|
|
@ -29,7 +29,7 @@ public:
|
||||||
ModuleBuilder(pybind11::object contextObj);
|
ModuleBuilder(pybind11::object contextObj);
|
||||||
|
|
||||||
/// Creates Python bindings for the class.
|
/// Creates Python bindings for the class.
|
||||||
static void bind(pybind11::module &m);
|
static void bind(pybind11::module& m);
|
||||||
|
|
||||||
pybind11::object getContextObj() { return contextObj; }
|
pybind11::object getContextObj() { return contextObj; }
|
||||||
pybind11::object getModuleObj() { return moduleObj; }
|
pybind11::object getModuleObj() { return moduleObj; }
|
||||||
|
@ -38,16 +38,15 @@ public:
|
||||||
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
||||||
// Just a bit of naming cruft.
|
// Just a bit of naming cruft.
|
||||||
// Returns the same function, making it suitable as a nested decorator.
|
// Returns the same function, making it suitable as a nested decorator.
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr importFunction(
|
||||||
importFunction(torch::jit::StrongFunctionPtr function,
|
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions);
|
||||||
py::object maybeImportOptions);
|
|
||||||
|
|
||||||
// Imports a torch::jit::Module into the current module, using the
|
// Imports a torch::jit::Module into the current module, using the
|
||||||
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
||||||
// a ClassAnnotator.
|
// a ClassAnnotator.
|
||||||
void importModule(torch::jit::Module jitModule,
|
void importModule(
|
||||||
py::object maybeClassAnnotator,
|
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||||
py::object maybeImportOptions);
|
py::object maybeImportOptions);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MlirBlock getBodyBlock();
|
MlirBlock getBodyBlock();
|
|
@ -33,41 +33,42 @@ class NodeImporter {
|
||||||
public:
|
public:
|
||||||
NodeImporter(MlirContext context) : context(context) {}
|
NodeImporter(MlirContext context) : context(context) {}
|
||||||
|
|
||||||
void importNode(Node *node, MlirBlock appendToBlock,
|
void importNode(
|
||||||
const ImportOptions &importOptions = {});
|
Node* node, MlirBlock appendToBlock,
|
||||||
|
const ImportOptions& importOptions = {});
|
||||||
MlirBlock importBlock(
|
MlirBlock importBlock(
|
||||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
Block* jitBlock, CreateTerminatorFn createTerminator,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions& importOptions = {});
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MlirBlock
|
MlirBlock createBlockFor(
|
||||||
createBlockFor(Block *jitBlock,
|
Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
const ImportOptions& importOptions = {});
|
||||||
const ImportOptions &importOptions = {});
|
void mapValue(Value* jitValue, MlirValue value);
|
||||||
void mapValue(Value *jitValue, MlirValue value);
|
void mapResults(Node* node, MlirOperation operation);
|
||||||
void mapResults(Node *node, MlirOperation operation);
|
MlirValue lookupMappedValue(Value* jitValue);
|
||||||
MlirValue lookupMappedValue(Value *jitValue);
|
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value*> values);
|
||||||
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);
|
|
||||||
|
|
||||||
MlirContext context;
|
MlirContext context;
|
||||||
std::unordered_map<Value *, MlirValue> valueMap;
|
std::unordered_map<Value*, MlirValue> valueMap;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
using InputsTransformFn =
|
using InputsTransformFn =
|
||||||
std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>;
|
std::function<std::vector<MlirValue>(std::vector<MlirValue>&)>;
|
||||||
|
|
||||||
// The inputs of `DictConstruct` in TorchScript IR are in the order
|
// The inputs of `DictConstruct` in TorchScript IR are in the order
|
||||||
// like k0, v0, k1, v1. Rearrange them to put the key operands together and
|
// like k0, v0, k1, v1. Rearrange them to put the key operands together and
|
||||||
// then the value operands like k0, k1,v0, v1. This is the expected format by
|
// then the value operands like k0, k1,v0, v1. This is the expected format by
|
||||||
// the corresponding MLIR op.
|
// the corresponding MLIR op.
|
||||||
static std::vector<MlirValue>
|
static std::vector<MlirValue>
|
||||||
rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
|
rearrangeDictConstructInputs(std::vector<MlirValue>& inputs) {
|
||||||
if (inputs.empty())
|
if (inputs.empty())
|
||||||
return inputs;
|
return inputs;
|
||||||
assert(inputs.size() % 2 == 0 &&
|
assert(
|
||||||
"DictConstruct must have even number of operands");
|
inputs.size() % 2 == 0 &&
|
||||||
|
"DictConstruct must have even number of operands");
|
||||||
|
|
||||||
std::vector<MlirValue> rearranged;
|
std::vector<MlirValue> rearranged;
|
||||||
std::vector<MlirValue> values;
|
std::vector<MlirValue> values;
|
||||||
|
@ -79,12 +80,12 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
|
||||||
return rearranged;
|
return rearranged;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
void NodeImporter::importNode(
|
||||||
const ImportOptions &importOptions) {
|
Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) {
|
||||||
MlirLocation loc = getMlirLocationFromNode(context, node);
|
MlirLocation loc = getMlirLocationFromNode(context, node);
|
||||||
auto kind = node->kind();
|
auto kind = node->kind();
|
||||||
|
|
||||||
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
|
auto createAndMapTrivialNode = [&](Node* node, const std::string& opName,
|
||||||
InputsTransformFn t) {
|
InputsTransformFn t) {
|
||||||
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
|
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
@ -95,7 +96,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto createAndMapNodeWithAttribute =
|
auto createAndMapNodeWithAttribute =
|
||||||
[&](Node *node, const std::string &opName, const std::string &attrName,
|
[&](Node* node, const std::string& opName, const std::string& attrName,
|
||||||
MlirAttribute attr) {
|
MlirAttribute attr) {
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
appendToBlock, opName, loc,
|
appendToBlock, opName, loc,
|
||||||
|
@ -132,27 +133,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
// ListConstruct and DictConstruct too.
|
// ListConstruct and DictConstruct too.
|
||||||
auto containedTypes = c10::fmap(
|
auto containedTypes = c10::fmap(
|
||||||
node->output()->type()->cast<c10::TupleType>()->containedTypes(),
|
node->output()->type()->cast<c10::TupleType>()->containedTypes(),
|
||||||
[&](const c10::TypePtr &t) {
|
[&](const c10::TypePtr& t) {
|
||||||
MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
|
MlirType type = getMlirTypeFromTorchType(loc, t, importOptions);
|
||||||
if (mlirTypeIsNull(type)) {
|
if (mlirTypeIsNull(type)) {
|
||||||
throw mlir_diagnostic_emitted();
|
throw mlir_diagnostic_emitted();
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
});
|
});
|
||||||
createAndMapTrivialNode(node,
|
createAndMapTrivialNode(
|
||||||
"torch.prim." + std::string(kind.toUnqualString()),
|
node, "torch.prim." + std::string(kind.toUnqualString()),
|
||||||
[&](std::vector<MlirValue> &inputs) {
|
[&](std::vector<MlirValue>& inputs) {
|
||||||
assert(containedTypes.size() == inputs.size());
|
assert(containedTypes.size() == inputs.size());
|
||||||
return adjustStaticInformationForValues(
|
return adjustStaticInformationForValues(
|
||||||
appendToBlock, loc, inputs, containedTypes,
|
appendToBlock, loc, inputs, containedTypes,
|
||||||
/*userAllowsRefinement=*/true);
|
/*userAllowsRefinement=*/true);
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case c10::prim::DictConstruct: {
|
case c10::prim::DictConstruct: {
|
||||||
createAndMapTrivialNode(node,
|
createAndMapTrivialNode(
|
||||||
"torch.prim." + std::string(kind.toUnqualString()),
|
node, "torch.prim." + std::string(kind.toUnqualString()),
|
||||||
rearrangeDictConstructInputs);
|
rearrangeDictConstructInputs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case c10::prim::Load:
|
case c10::prim::Load:
|
||||||
|
@ -170,32 +171,34 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
auto output = node->output();
|
auto output = node->output();
|
||||||
MlirOperation op;
|
MlirOperation op;
|
||||||
if (output->type()->cast<c10::NoneType>()) {
|
if (output->type()->cast<c10::NoneType>()) {
|
||||||
op = createMlirOperation("torch.constant.none", loc,
|
op = createMlirOperation(
|
||||||
torchMlirTorchNoneTypeGet(context));
|
"torch.constant.none", loc, torchMlirTorchNoneTypeGet(context));
|
||||||
} else if (output->type()->cast<c10::BoolType>()) {
|
} else if (output->type()->cast<c10::BoolType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
|
"value",
|
||||||
c10::attr::value)))));
|
mlirBoolAttrGet(
|
||||||
|
context, static_cast<bool>(node->i(c10::attr::value)))));
|
||||||
} else if (output->type()->cast<c10::IntType>()) {
|
} else if (output->type()->cast<c10::IntType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.int", loc,
|
"torch.constant.int", loc,
|
||||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute(
|
||||||
importAttribute(loc, node, c10::attr::value)));
|
"value", importAttribute(loc, node, c10::attr::value)));
|
||||||
} else if (output->type()->cast<c10::FloatType>()) {
|
} else if (output->type()->cast<c10::FloatType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.float", loc,
|
"torch.constant.float", loc,
|
||||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute(
|
||||||
importAttribute(loc, node, c10::attr::value)));
|
"value", importAttribute(loc, node, c10::attr::value)));
|
||||||
} else if (output->type()->cast<c10::StringType>()) {
|
} else if (output->type()->cast<c10::StringType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
|
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
"value",
|
||||||
c10::attr::value)))));
|
mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(node->s(c10::attr::value)))));
|
||||||
} else if (output->type()->cast<c10::TensorType>()) {
|
} else if (output->type()->cast<c10::TensorType>()) {
|
||||||
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
||||||
if (importOptions.assumeTensorsHaveValueSemantics) {
|
if (importOptions.assumeTensorsHaveValueSemantics) {
|
||||||
|
@ -214,24 +217,26 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
"torch.constant.device", loc,
|
"torch.constant.device", loc,
|
||||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
"value",
|
||||||
c10::attr::value)))));
|
mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(node->s(c10::attr::value)))));
|
||||||
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
|
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
|
||||||
torch::jit::Function *function = functionType->function();
|
torch::jit::Function* function = functionType->function();
|
||||||
const std::string &symName = function->qualname().qualifiedName();
|
const std::string& symName = function->qualname().qualifiedName();
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"func.constant", loc,
|
"func.constant", loc,
|
||||||
getFunctionTypeFromSchema(context, function->getSchema(),
|
getFunctionTypeFromSchema(
|
||||||
importOptions),
|
context, function->getSchema(), importOptions),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value",
|
"value",
|
||||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||||
} else if (output->type()->cast<c10::ListType>() ||
|
} else if (
|
||||||
output->type()->cast<c10::TupleType>()) {
|
output->type()->cast<c10::ListType>() ||
|
||||||
|
output->type()->cast<c10::TupleType>()) {
|
||||||
ClassAnnotator dummyAnnotator;
|
ClassAnnotator dummyAnnotator;
|
||||||
MlirValue listOrTupleValue =
|
MlirValue listOrTupleValue = importIValue(
|
||||||
importIValue(node->ival(c10::attr::value), appendToBlock, context,
|
node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator,
|
||||||
dummyAnnotator, importOptions);
|
importOptions);
|
||||||
mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
|
mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
|
||||||
return; // Early return, since `importIValue` already added op to block.
|
return; // Early return, since `importIValue` already added op to block.
|
||||||
} else {
|
} else {
|
||||||
|
@ -259,19 +264,20 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
std::vector<MlirType> terminatorOperandTypes = {
|
std::vector<MlirType> terminatorOperandTypes = {
|
||||||
torchMlirTorchBoolTypeGet(context)};
|
torchMlirTorchBoolTypeGet(context)};
|
||||||
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
terminatorOperandTypes.insert(
|
||||||
resultTypes.begin(), resultTypes.end());
|
terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end());
|
||||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||||
MlirBlock appendToBlock) {
|
MlirBlock appendToBlock) {
|
||||||
createMlirOperationAtEnd(
|
createMlirOperationAtEnd(
|
||||||
appendToBlock, "torch.prim.Loop.condition", loc,
|
appendToBlock, "torch.prim.Loop.condition", loc,
|
||||||
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
|
adjustStaticInformationForValues(
|
||||||
terminatorOperandTypes,
|
appendToBlock, loc, yieldedValues, terminatorOperandTypes,
|
||||||
/*userAllowsRefinement=*/false));
|
/*userAllowsRefinement=*/false));
|
||||||
};
|
};
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(
|
||||||
mlirOperationGetRegion(operation, 0),
|
mlirOperationGetRegion(operation, 0),
|
||||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
importBlock(
|
||||||
|
node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,25 +292,27 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
MlirBlock appendToBlock) {
|
MlirBlock appendToBlock) {
|
||||||
createMlirOperationAtEnd(
|
createMlirOperationAtEnd(
|
||||||
appendToBlock, "torch.prim.If.yield", loc,
|
appendToBlock, "torch.prim.If.yield", loc,
|
||||||
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
|
adjustStaticInformationForValues(
|
||||||
resultTypes,
|
appendToBlock, loc, yieldedValues, resultTypes,
|
||||||
/*userAllowsRefinement=*/false));
|
/*userAllowsRefinement=*/false));
|
||||||
};
|
};
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(
|
||||||
mlirOperationGetRegion(operation, 0),
|
mlirOperationGetRegion(operation, 0),
|
||||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
importBlock(
|
||||||
|
node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(
|
||||||
mlirOperationGetRegion(operation, 1),
|
mlirOperationGetRegion(operation, 1),
|
||||||
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
|
importBlock(
|
||||||
|
node->blocks()[1], createTerminator, c10::nullopt, importOptions));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::CallMethod) {
|
if (kind == c10::prim::CallMethod) {
|
||||||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||||
auto methodName = node->s(c10::attr::name);
|
auto methodName = node->s(c10::attr::name);
|
||||||
torch::jit::Function *function = classType->findMethod(methodName);
|
torch::jit::Function* function = classType->findMethod(methodName);
|
||||||
MlirType calleeType =
|
MlirType calleeType = getFunctionTypeFromSchema(
|
||||||
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
|
context, function->getSchema(), importOptions);
|
||||||
std::vector<MlirType> expectedTypes;
|
std::vector<MlirType> expectedTypes;
|
||||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
||||||
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
||||||
|
@ -315,17 +323,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
adjustStaticInformationForValues(
|
adjustStaticInformationForValues(
|
||||||
appendToBlock, loc, lookupMappedValues(node->inputs()),
|
appendToBlock, loc, lookupMappedValues(node->inputs()),
|
||||||
expectedTypes, /*userAllowsRefinement=*/false),
|
expectedTypes, /*userAllowsRefinement=*/false),
|
||||||
toMlirNamedAttribute("name",
|
toMlirNamedAttribute(
|
||||||
importAttribute(loc, node, c10::attr::name)));
|
"name", importAttribute(loc, node, c10::attr::name)));
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::CallFunction) {
|
if (kind == c10::prim::CallFunction) {
|
||||||
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||||
torch::jit::Block *calleeEntryBlock =
|
torch::jit::Block* calleeEntryBlock =
|
||||||
torch::jit::toGraphFunction(*functionType->function()).graph()->block();
|
torch::jit::toGraphFunction(*functionType->function()).graph()->block();
|
||||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value* v) {
|
||||||
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
|
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
|
||||||
});
|
});
|
||||||
std::string functionName = node->input(0)->node()->s(c10::attr::name);
|
std::string functionName = node->input(0)->node()->s(c10::attr::name);
|
||||||
|
@ -340,9 +348,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
// promoted result dtype for a PyTorch computation. Here we turn the call to
|
// promoted result dtype for a PyTorch computation. Here we turn the call to
|
||||||
// this function to the torch dialect equivalent op `torch.promote_dtypes`.
|
// this function to the torch dialect equivalent op `torch.promote_dtypes`.
|
||||||
if (functionName == "__torch_mlir_internal_promote_dtypes") {
|
if (functionName == "__torch_mlir_internal_promote_dtypes") {
|
||||||
operation =
|
operation = createMlirOperationAtEnd(
|
||||||
createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc,
|
appendToBlock, "torch.promote_dtypes", loc, resultTypes,
|
||||||
resultTypes, adjustedFuncArgs);
|
adjustedFuncArgs);
|
||||||
} else {
|
} else {
|
||||||
operation = createMlirOperationAtEnd(
|
operation = createMlirOperationAtEnd(
|
||||||
appendToBlock, "func.call_indirect", loc, resultTypes,
|
appendToBlock, "func.call_indirect", loc, resultTypes,
|
||||||
|
@ -362,22 +370,22 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock NodeImporter::importBlock(
|
MlirBlock NodeImporter::importBlock(
|
||||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
Block* jitBlock, CreateTerminatorFn createTerminator,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
const ImportOptions &importOptions) {
|
const ImportOptions& importOptions) {
|
||||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
||||||
for (Node *node : jitBlock->nodes()) {
|
for (Node* node : jitBlock->nodes()) {
|
||||||
importNode(node, block, importOptions);
|
importNode(node, block, importOptions);
|
||||||
}
|
}
|
||||||
Node *returnNode = jitBlock->return_node();
|
Node* returnNode = jitBlock->return_node();
|
||||||
createTerminator(lookupMappedValues(returnNode->inputs()), block);
|
createTerminator(lookupMappedValues(returnNode->inputs()), block);
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock NodeImporter::createBlockFor(
|
MlirBlock NodeImporter::createBlockFor(
|
||||||
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
const ImportOptions &importOptions) {
|
const ImportOptions& importOptions) {
|
||||||
Node *paramNode = jitBlock->param_node();
|
Node* paramNode = jitBlock->param_node();
|
||||||
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
||||||
std::vector<MlirType> paramNodeTypes =
|
std::vector<MlirType> paramNodeTypes =
|
||||||
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
|
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
|
||||||
|
@ -386,11 +394,11 @@ MlirBlock NodeImporter::createBlockFor(
|
||||||
else
|
else
|
||||||
assert(blockArgTypes->size() == paramNodeTypes.size());
|
assert(blockArgTypes->size() == paramNodeTypes.size());
|
||||||
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
|
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
|
||||||
MlirBlock block =
|
MlirBlock block = mlirBlockCreate(
|
||||||
mlirBlockCreate(blockArgTypes.value().size(),
|
blockArgTypes.value().size(), blockArgTypes.value().data(),
|
||||||
blockArgTypes.value().data(), blockArgLocs.data());
|
blockArgLocs.data());
|
||||||
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
||||||
Value *jitValue = paramNode->outputs()[i];
|
Value* jitValue = paramNode->outputs()[i];
|
||||||
MlirValue value = mlirBlockGetArgument(block, i);
|
MlirValue value = mlirBlockGetArgument(block, i);
|
||||||
MlirValue adjusted = adjustStaticInformationForValues(
|
MlirValue adjusted = adjustStaticInformationForValues(
|
||||||
block, loc, {value}, {paramNodeTypes[i]},
|
block, loc, {value}, {paramNodeTypes[i]},
|
||||||
|
@ -400,39 +408,40 @@ MlirBlock NodeImporter::createBlockFor(
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
|
void NodeImporter::mapValue(Value* jitValue, MlirValue value) {
|
||||||
auto it = valueMap.find(jitValue);
|
auto it = valueMap.find(jitValue);
|
||||||
(void)it;
|
(void)it;
|
||||||
assert(it == valueMap.end() && "jitValue has already been mapped");
|
assert(it == valueMap.end() && "jitValue has already been mapped");
|
||||||
valueMap[jitValue] = value;
|
valueMap[jitValue] = value;
|
||||||
}
|
}
|
||||||
void NodeImporter::mapResults(Node *node, MlirOperation operation) {
|
void NodeImporter::mapResults(Node* node, MlirOperation operation) {
|
||||||
assert(node->outputs().size() ==
|
assert(
|
||||||
(size_t)mlirOperationGetNumResults(operation));
|
node->outputs().size() == (size_t)mlirOperationGetNumResults(operation));
|
||||||
for (int i = 0, e = node->outputs().size(); i < e; i++) {
|
for (int i = 0, e = node->outputs().size(); i < e; i++) {
|
||||||
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
|
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
|
MlirValue NodeImporter::lookupMappedValue(Value* jitValue) {
|
||||||
auto it = valueMap.find(jitValue);
|
auto it = valueMap.find(jitValue);
|
||||||
assert(it != valueMap.end() &&
|
assert(
|
||||||
"trying to get mapping for jitValue that is not mapped yet!");
|
it != valueMap.end() &&
|
||||||
|
"trying to get mapping for jitValue that is not mapped yet!");
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
std::vector<MlirValue>
|
std::vector<MlirValue>
|
||||||
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
NodeImporter::lookupMappedValues(c10::ArrayRef<Value*> values) {
|
||||||
std::vector<MlirValue> ret;
|
std::vector<MlirValue> ret;
|
||||||
for (Value *value : values) {
|
for (Value* value : values) {
|
||||||
ret.push_back(lookupMappedValue(value));
|
ret.push_back(lookupMappedValue(value));
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock
|
MlirBlock torch_mlir::importBlock(
|
||||||
torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator,
|
||||||
CreateTerminatorFn createTerminator,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
const ImportOptions& importOptions) {
|
||||||
const ImportOptions &importOptions) {
|
|
||||||
NodeImporter importer(context);
|
NodeImporter importer(context);
|
||||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
|
return importer.importBlock(
|
||||||
|
jitBlock, createTerminator, blockArgTypes, importOptions);
|
||||||
}
|
}
|
|
@ -37,10 +37,10 @@ using CreateTerminatorFn =
|
||||||
/// adjust the types to the block argument types.
|
/// adjust the types to the block argument types.
|
||||||
/// TODO: Formalize what type conversions are allowed here.
|
/// TODO: Formalize what type conversions are allowed here.
|
||||||
MlirBlock importBlock(
|
MlirBlock importBlock(
|
||||||
MlirContext context, torch::jit::Block *jitBlock,
|
MlirContext context, torch::jit::Block* jitBlock,
|
||||||
CreateTerminatorFn createTerminator,
|
CreateTerminatorFn createTerminator,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions& importOptions = {});
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -26,8 +26,8 @@
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
|
static MlirType getMlirTypeForTorchScalarTypeRaw(
|
||||||
c10::ScalarType scalarType) {
|
MlirContext context, c10::ScalarType scalarType) {
|
||||||
using c10::ScalarType;
|
using c10::ScalarType;
|
||||||
switch (scalarType) {
|
switch (scalarType) {
|
||||||
case ScalarType::Byte:
|
case ScalarType::Byte:
|
||||||
|
@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
|
MlirType torch_mlir::getMlirTypeForTorchScalarType(
|
||||||
c10::ScalarType scalarType) {
|
MlirLocation loc, c10::ScalarType scalarType) {
|
||||||
auto type =
|
auto type =
|
||||||
getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
|
getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
|
||||||
if (mlirTypeIsNull(type)) {
|
if (mlirTypeIsNull(type)) {
|
||||||
|
@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
|
||||||
// There is no generic way to import custom classes (or their types), so we
|
// There is no generic way to import custom classes (or their types), so we
|
||||||
// have to name match them here (and the relevant code in the ivalue
|
// have to name match them here (and the relevant code in the ivalue
|
||||||
// importer) and create special IR constructs for them.
|
// importer) and create special IR constructs for them.
|
||||||
static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
static MlirType mapCustomClassType(
|
||||||
const c10::ClassTypePtr &classType) {
|
MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) {
|
||||||
// If the type is unnamed, it cannot be a custom class.
|
// If the type is unnamed, it cannot be a custom class.
|
||||||
if (!classType->name().has_value()) {
|
if (!classType->name().has_value()) {
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
|
@ -126,10 +126,9 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
||||||
throw mlir_diagnostic_emitted();
|
throw mlir_diagnostic_emitted();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType
|
MlirType torch_mlir::getMlirTypeFromTorchType(
|
||||||
torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
MlirLocation loc, const c10::TypePtr& torchType,
|
||||||
const c10::TypePtr &torchType,
|
const ImportOptions& importOptions) {
|
||||||
const ImportOptions &importOptions) {
|
|
||||||
MlirContext context = mlirLocationGetContext(loc);
|
MlirContext context = mlirLocationGetContext(loc);
|
||||||
using c10::TypeKind;
|
using c10::TypeKind;
|
||||||
auto kind = torchType->kind();
|
auto kind = torchType->kind();
|
||||||
|
@ -141,10 +140,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
: torchMlirTorchNonValueTensorTypeGet;
|
: torchMlirTorchNonValueTensorTypeGet;
|
||||||
|
|
||||||
if (importOptions.ignoreExistingTensorShapesAndDtypes) {
|
if (importOptions.ignoreExistingTensorShapesAndDtypes) {
|
||||||
return getMlirTensorType(context,
|
return getMlirTensorType(
|
||||||
/*numSizes=*/-1,
|
context,
|
||||||
/*optionalSizes=*/nullptr,
|
/*numSizes=*/-1,
|
||||||
/*optionalDtype=*/{nullptr});
|
/*optionalSizes=*/nullptr,
|
||||||
|
/*optionalDtype=*/{nullptr});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Element type.
|
// Element type.
|
||||||
|
@ -156,17 +156,18 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
}
|
}
|
||||||
// Sizes.
|
// Sizes.
|
||||||
auto &sizes = tensorType->symbolic_sizes();
|
auto& sizes = tensorType->symbolic_sizes();
|
||||||
if (!sizes.rank()) {
|
if (!sizes.rank()) {
|
||||||
// Unranked.
|
// Unranked.
|
||||||
return getMlirTensorType(context,
|
return getMlirTensorType(
|
||||||
/*numSizes=*/-1,
|
context,
|
||||||
/*optionalSizes=*/nullptr,
|
/*numSizes=*/-1,
|
||||||
/*optionalDtype=*/
|
/*optionalSizes=*/nullptr,
|
||||||
elementType);
|
/*optionalDtype=*/
|
||||||
|
elementType);
|
||||||
}
|
}
|
||||||
// Ranked with possibly dynamic dims.
|
// Ranked with possibly dynamic dims.
|
||||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
auto& symbolicShape = tensorType->symbolic_sizes();
|
||||||
std::vector<int64_t> dims;
|
std::vector<int64_t> dims;
|
||||||
dims.resize(*sizes.rank());
|
dims.resize(*sizes.rank());
|
||||||
for (size_t i = 0; i < dims.size(); ++i) {
|
for (size_t i = 0; i < dims.size(); ++i) {
|
||||||
|
@ -179,11 +180,12 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
// the C API constructor, when we want the "we know we have 0 sizes"
|
// the C API constructor, when we want the "we know we have 0 sizes"
|
||||||
// case. So use a dummy data pointer.
|
// case. So use a dummy data pointer.
|
||||||
int64_t dummy;
|
int64_t dummy;
|
||||||
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
|
int64_t* dimsData = dims.size() == 0 ? &dummy : dims.data();
|
||||||
return getMlirTensorType(context, dims.size(),
|
return getMlirTensorType(
|
||||||
/*optionalSizes=*/dimsData,
|
context, dims.size(),
|
||||||
/*optionalDtype=*/
|
/*optionalSizes=*/dimsData,
|
||||||
elementType);
|
/*optionalDtype=*/
|
||||||
|
elementType);
|
||||||
}
|
}
|
||||||
case TypeKind::IntType: {
|
case TypeKind::IntType: {
|
||||||
return torchMlirTorchIntTypeGet(context);
|
return torchMlirTorchIntTypeGet(context);
|
||||||
|
@ -207,22 +209,22 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
}
|
}
|
||||||
case TypeKind::TupleType: {
|
case TypeKind::TupleType: {
|
||||||
std::vector<MlirType> containedTypes;
|
std::vector<MlirType> containedTypes;
|
||||||
for (const c10::TypePtr &type :
|
for (const c10::TypePtr& type :
|
||||||
torchType->cast<c10::TupleType>()->containedTypes()) {
|
torchType->cast<c10::TupleType>()->containedTypes()) {
|
||||||
containedTypes.push_back(
|
containedTypes.push_back(
|
||||||
getMlirTypeFromTorchType(loc, type, importOptions));
|
getMlirTypeFromTorchType(loc, type, importOptions));
|
||||||
}
|
}
|
||||||
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
|
return torchMlirTorchTupleTypeGet(
|
||||||
containedTypes.data());
|
context, containedTypes.size(), containedTypes.data());
|
||||||
}
|
}
|
||||||
case TypeKind::UnionType: {
|
case TypeKind::UnionType: {
|
||||||
std::vector<MlirType> containedTypes;
|
std::vector<MlirType> containedTypes;
|
||||||
for (const c10::TypePtr &type :
|
for (const c10::TypePtr& type :
|
||||||
torchType->cast<c10::UnionType>()->containedTypes()) {
|
torchType->cast<c10::UnionType>()->containedTypes()) {
|
||||||
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
|
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
|
||||||
}
|
}
|
||||||
return torchMlirTorchUnionTypeGet(context, containedTypes.size(),
|
return torchMlirTorchUnionTypeGet(
|
||||||
containedTypes.data());
|
context, containedTypes.size(), containedTypes.data());
|
||||||
}
|
}
|
||||||
case TypeKind::ListType: {
|
case TypeKind::ListType: {
|
||||||
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
|
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
|
||||||
|
@ -242,7 +244,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
return torchMlirTorchAnyTypeGet(context);
|
return torchMlirTorchAnyTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::ClassType: {
|
case TypeKind::ClassType: {
|
||||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
const c10::ClassTypePtr& classType = torchType->cast<c10::ClassType>();
|
||||||
MlirType customClassType = mapCustomClassType(context, loc, classType);
|
MlirType customClassType = mapCustomClassType(context, loc, classType);
|
||||||
if (!mlirTypeIsNull(customClassType)) {
|
if (!mlirTypeIsNull(customClassType)) {
|
||||||
return customClassType;
|
return customClassType;
|
||||||
|
@ -266,12 +268,11 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType
|
MlirType torch_mlir::getFunctionTypeFromSchema(
|
||||||
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
MlirContext context, const c10::FunctionSchema& schema,
|
||||||
const c10::FunctionSchema &schema,
|
const ImportOptions& importOptions) {
|
||||||
const ImportOptions &importOptions) {
|
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
auto mapType = [&](const c10::TypePtr &torchType) {
|
auto mapType = [&](const c10::TypePtr& torchType) {
|
||||||
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
|
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
|
||||||
if (mlirTypeIsNull(type)) {
|
if (mlirTypeIsNull(type)) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
|
@ -283,17 +284,20 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<MlirType> inputTypes =
|
std::vector<MlirType> inputTypes =
|
||||||
c10::fmap(schema.arguments(),
|
c10::fmap(schema.arguments(), [&](const c10::Argument& arg) {
|
||||||
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
return mapType(arg.type());
|
||||||
|
});
|
||||||
std::vector<MlirType> outputTypes =
|
std::vector<MlirType> outputTypes =
|
||||||
c10::fmap(schema.returns(),
|
c10::fmap(schema.returns(), [&](const c10::Argument& arg) {
|
||||||
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
return mapType(arg.type());
|
||||||
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
});
|
||||||
outputTypes.size(), outputTypes.data());
|
return mlirFunctionTypeGet(
|
||||||
|
context, inputTypes.size(), inputTypes.data(), outputTypes.size(),
|
||||||
|
outputTypes.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
|
||||||
MlirLocation loc) {
|
at::Tensor tensor, MlirLocation loc) {
|
||||||
using at::ScalarType;
|
using at::ScalarType;
|
||||||
|
|
||||||
auto throwUnsupportedTensorError = [&]() {
|
auto throwUnsupportedTensorError = [&]() {
|
||||||
|
@ -308,8 +312,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
|
|
||||||
// The flat number of bytes throws an exception for tensors that are not
|
// The flat number of bytes throws an exception for tensors that are not
|
||||||
// dense and accessible as such.
|
// dense and accessible as such.
|
||||||
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
|
at::checkLayout(
|
||||||
c10::Layout::Strided);
|
at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided);
|
||||||
|
|
||||||
// Construct the ShapedType.
|
// Construct the ShapedType.
|
||||||
|
|
||||||
|
@ -334,47 +338,47 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
switch (tensor.scalar_type()) {
|
switch (tensor.scalar_type()) {
|
||||||
case ScalarType::Int:
|
case ScalarType::Int:
|
||||||
return mlirDenseElementsAttrInt32Get(
|
return mlirDenseElementsAttrInt32Get(
|
||||||
shapedType, numElements, static_cast<const int32_t *>(tensorData));
|
shapedType, numElements, static_cast<const int32_t*>(tensorData));
|
||||||
break;
|
break;
|
||||||
case ScalarType::Long:
|
case ScalarType::Long:
|
||||||
return mlirDenseElementsAttrInt64Get(
|
return mlirDenseElementsAttrInt64Get(
|
||||||
shapedType, numElements, static_cast<const int64_t *>(tensorData));
|
shapedType, numElements, static_cast<const int64_t*>(tensorData));
|
||||||
break;
|
break;
|
||||||
case ScalarType::Float:
|
case ScalarType::Float:
|
||||||
return mlirDenseElementsAttrFloatGet(
|
return mlirDenseElementsAttrFloatGet(
|
||||||
shapedType, numElements, static_cast<const float *>(tensorData));
|
shapedType, numElements, static_cast<const float*>(tensorData));
|
||||||
break;
|
break;
|
||||||
case ScalarType::Double:
|
case ScalarType::Double:
|
||||||
return mlirDenseElementsAttrDoubleGet(
|
return mlirDenseElementsAttrDoubleGet(
|
||||||
shapedType, numElements, static_cast<const double *>(tensorData));
|
shapedType, numElements, static_cast<const double*>(tensorData));
|
||||||
break;
|
break;
|
||||||
case ScalarType::Bool: {
|
case ScalarType::Bool: {
|
||||||
// TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed
|
// TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed
|
||||||
// upstream to take in a `const bool *` rather than a `const int *` to avoid
|
// upstream to take in a `const bool *` rather than a `const int *` to avoid
|
||||||
// the unnecessary copying into an array four times as large.
|
// the unnecessary copying into an array four times as large.
|
||||||
const int8_t *elements = static_cast<const int8_t *>(tensorData);
|
const int8_t* elements = static_cast<const int8_t*>(tensorData);
|
||||||
std::vector<int> tensorDataVector(elements, elements + numElements);
|
std::vector<int> tensorDataVector(elements, elements + numElements);
|
||||||
return mlirDenseElementsAttrBoolGet(shapedType, numElements,
|
return mlirDenseElementsAttrBoolGet(
|
||||||
tensorDataVector.data());
|
shapedType, numElements, tensorDataVector.data());
|
||||||
} break;
|
} break;
|
||||||
case ScalarType::QInt8:
|
case ScalarType::QInt8:
|
||||||
return mlirDenseElementsAttrInt8Get(
|
return mlirDenseElementsAttrInt8Get(
|
||||||
shapedType, numElements, static_cast<const int8_t *>(tensorData));
|
shapedType, numElements, static_cast<const int8_t*>(tensorData));
|
||||||
case ScalarType::QUInt8:
|
case ScalarType::QUInt8:
|
||||||
return mlirDenseElementsAttrUInt8Get(
|
return mlirDenseElementsAttrUInt8Get(
|
||||||
shapedType, numElements, static_cast<const uint8_t *>(tensorData));
|
shapedType, numElements, static_cast<const uint8_t*>(tensorData));
|
||||||
case ScalarType::BFloat16:
|
case ScalarType::BFloat16:
|
||||||
return mlirDenseElementsAttrBFloat16Get(
|
return mlirDenseElementsAttrBFloat16Get(
|
||||||
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
|
shapedType, numElements, static_cast<const uint16_t*>(tensorData));
|
||||||
case ScalarType::Half:
|
case ScalarType::Half:
|
||||||
return mlirDenseElementsAttrFloat16Get(
|
return mlirDenseElementsAttrFloat16Get(
|
||||||
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
|
shapedType, numElements, static_cast<const uint16_t*>(tensorData));
|
||||||
case ScalarType::Byte:
|
case ScalarType::Byte:
|
||||||
return mlirDenseElementsAttrUInt8Get(
|
return mlirDenseElementsAttrUInt8Get(
|
||||||
shapedType, numElements, static_cast<const uint8_t *>(tensorData));
|
shapedType, numElements, static_cast<const uint8_t*>(tensorData));
|
||||||
case ScalarType::Char:
|
case ScalarType::Char:
|
||||||
return mlirDenseElementsAttrInt8Get(
|
return mlirDenseElementsAttrInt8Get(
|
||||||
shapedType, numElements, static_cast<const int8_t *>(tensorData));
|
shapedType, numElements, static_cast<const int8_t*>(tensorData));
|
||||||
|
|
||||||
default:
|
default:
|
||||||
throwUnsupportedTensorError();
|
throwUnsupportedTensorError();
|
||||||
|
@ -382,9 +386,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
return {nullptr}; // Unreachable.
|
return {nullptr}; // Unreachable.
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
|
MlirAttribute torch_mlir::importAttribute(
|
||||||
torch::jit::Node *node,
|
MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) {
|
||||||
c10::Symbol symbol) {
|
|
||||||
MlirContext context = mlirLocationGetContext(loc);
|
MlirContext context = mlirLocationGetContext(loc);
|
||||||
auto kind = node->kindOf(symbol);
|
auto kind = node->kindOf(symbol);
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
@ -393,8 +396,8 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
|
||||||
// do that.
|
// do that.
|
||||||
return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol));
|
return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol));
|
||||||
case torch::jit::AttributeKind::f:
|
case torch::jit::AttributeKind::f:
|
||||||
return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
|
return mlirFloatAttrDoubleGet(
|
||||||
node->f(symbol));
|
context, mlirF64TypeGet(context), node->f(symbol));
|
||||||
case torch::jit::AttributeKind::s:
|
case torch::jit::AttributeKind::s:
|
||||||
return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
|
return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
|
||||||
case torch::jit::AttributeKind::t:
|
case torch::jit::AttributeKind::t:
|
||||||
|
@ -408,23 +411,23 @@ MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
|
MlirLocation torch_mlir::getMlirLocationFromNode(
|
||||||
torch::jit::Node *node) {
|
MlirContext context, torch::jit::Node* node) {
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
|
if (node->hasAttribute(c10::Symbol::attr("source_files"))) {
|
||||||
const auto &sourceFiles = node->ss(c10::Symbol::attr("source_files"));
|
const auto& sourceFiles = node->ss(c10::Symbol::attr("source_files"));
|
||||||
const auto &lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
|
const auto& lineNumbers = node->is(c10::Symbol::attr("line_numbers"));
|
||||||
const auto &functions = node->ss(c10::Symbol::attr("functions"));
|
const auto& functions = node->ss(c10::Symbol::attr("functions"));
|
||||||
|
|
||||||
// Chain a sequence of calls to construct single MlirLocation.
|
// Chain a sequence of calls to construct single MlirLocation.
|
||||||
for (const auto i : c10::irange(sourceFiles.size())) {
|
for (const auto i : c10::irange(sourceFiles.size())) {
|
||||||
MlirLocation newLoc = mlirLocationNameGet(
|
MlirLocation newLoc = mlirLocationNameGet(
|
||||||
context, toMlirStringRef(functions[i]),
|
context, toMlirStringRef(functions[i]),
|
||||||
mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]),
|
mlirLocationFileLineColGet(
|
||||||
lineNumbers[i],
|
context, toMlirStringRef(sourceFiles[i]), lineNumbers[i],
|
||||||
0 /* column is not available */
|
0 /* column is not available */
|
||||||
));
|
));
|
||||||
loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
|
loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));
|
||||||
}
|
}
|
||||||
if (sourceFiles.size() == 1) {
|
if (sourceFiles.size() == 1) {
|
||||||
|
@ -433,7 +436,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
|
||||||
loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
|
loc = mlirLocationCallSiteGet(loc, mlirLocationUnknownGet(context));
|
||||||
}
|
}
|
||||||
} else if (auto flc = node->sourceRange().file_line_col()) {
|
} else if (auto flc = node->sourceRange().file_line_col()) {
|
||||||
const std::string &file = std::get<0>(*flc);
|
const std::string& file = std::get<0>(*flc);
|
||||||
int line = std::get<1>(*flc);
|
int line = std::get<1>(*flc);
|
||||||
int col = std::get<2>(*flc);
|
int col = std::get<2>(*flc);
|
||||||
loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
|
loc = mlirLocationFileLineColGet(context, toMlirStringRef(file), line, col);
|
||||||
|
@ -445,7 +448,7 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
|
||||||
locationName = scopeName;
|
locationName = scopeName;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const c10::FunctionSchema *schema = node->maybeSchema()) {
|
if (const c10::FunctionSchema* schema = node->maybeSchema()) {
|
||||||
if (!locationName.empty()) {
|
if (!locationName.empty()) {
|
||||||
locationName += "/";
|
locationName += "/";
|
||||||
}
|
}
|
||||||
|
@ -459,10 +462,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
|
||||||
return loc;
|
return loc;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MlirType>
|
std::vector<MlirType> torch_mlir::getMlirTypesFromValues(
|
||||||
torch_mlir::getMlirTypesFromValues(MlirLocation loc,
|
MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
|
||||||
c10::ArrayRef<torch::jit::Value *> values,
|
const ImportOptions& importOptions) {
|
||||||
const ImportOptions &importOptions) {
|
|
||||||
std::vector<MlirType> ret;
|
std::vector<MlirType> ret;
|
||||||
for (auto value : values) {
|
for (auto value : values) {
|
||||||
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
|
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
|
||||||
|
@ -491,25 +493,24 @@ std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
|
MlirStringCallback printToStream = +[](MlirStringRef str, void* userData) {
|
||||||
std::stringstream *stream = static_cast<std::stringstream *>(userData);
|
std::stringstream* stream = static_cast<std::stringstream*>(userData);
|
||||||
stream->write(str.data, str.length);
|
stream->write(str.data, str.length);
|
||||||
};
|
};
|
||||||
msg << "unhandled: could not adjust static info for type from ";
|
msg << "unhandled: could not adjust static info for type from ";
|
||||||
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
|
mlirTypePrint(type, printToStream, static_cast<void*>(&msg));
|
||||||
msg << " to type ";
|
msg << " to type ";
|
||||||
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
|
mlirTypePrint(expectedType, printToStream, static_cast<void*>(&msg));
|
||||||
mlirEmitError(loc, msg.str().c_str());
|
mlirEmitError(loc, msg.str().c_str());
|
||||||
throw mlir_diagnostic_emitted();
|
throw mlir_diagnostic_emitted();
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirOperation
|
MlirOperation torch_mlir::createOperationFromSchema(
|
||||||
torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc,
|
MlirBlock appendToBlock, MlirLocation loc,
|
||||||
const c10::FunctionSchema &schema,
|
const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
|
||||||
c10::ArrayRef<MlirType> resultTypes,
|
c10::ArrayRef<MlirValue> operands) {
|
||||||
c10::ArrayRef<MlirValue> operands) {
|
|
||||||
MlirContext context = mlirLocationGetContext(loc);
|
MlirContext context = mlirLocationGetContext(loc);
|
||||||
|
|
||||||
// Munge the name into the appropriate MLIR operation name.
|
// Munge the name into the appropriate MLIR operation name.
|
||||||
|
@ -519,15 +520,15 @@ torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc,
|
||||||
auto separatorPosition = opNameSuffix.find_first_of("::");
|
auto separatorPosition = opNameSuffix.find_first_of("::");
|
||||||
assert(separatorPosition != std::string::npos);
|
assert(separatorPosition != std::string::npos);
|
||||||
opNameSuffix.replace(separatorPosition, 2, ".");
|
opNameSuffix.replace(separatorPosition, 2, ".");
|
||||||
const std::string &overloadName = schema.overload_name();
|
const std::string& overloadName = schema.overload_name();
|
||||||
if (!overloadName.empty()) {
|
if (!overloadName.empty()) {
|
||||||
opNameSuffix = opNameSuffix + "." + overloadName;
|
opNameSuffix = opNameSuffix + "." + overloadName;
|
||||||
}
|
}
|
||||||
std::string opName = "torch." + opNameSuffix;
|
std::string opName = "torch." + opNameSuffix;
|
||||||
// If we have a registered op, use it!
|
// If we have a registered op, use it!
|
||||||
if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) {
|
if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) {
|
||||||
return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes,
|
return createMlirOperationAtEnd(
|
||||||
operands);
|
appendToBlock, opName, loc, resultTypes, operands);
|
||||||
}
|
}
|
||||||
// Oops, no registered op -- create an opaque wrapper so that import can
|
// Oops, no registered op -- create an opaque wrapper so that import can
|
||||||
// still succeed. This helps a common use case of filling out registered ops
|
// still succeed. This helps a common use case of filling out registered ops
|
|
@ -25,7 +25,7 @@ namespace torch_mlir {
|
||||||
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
||||||
class mlir_diagnostic_emitted : public std::runtime_error {
|
class mlir_diagnostic_emitted : public std::runtime_error {
|
||||||
public:
|
public:
|
||||||
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
mlir_diagnostic_emitted(const char* what) : std::runtime_error(what) {}
|
||||||
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -38,37 +38,36 @@ public:
|
||||||
/// for Python code).
|
/// for Python code).
|
||||||
///
|
///
|
||||||
/// Returns a null type on failure and emits a diagnostic.
|
/// Returns a null type on failure and emits a diagnostic.
|
||||||
MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
|
MlirType
|
||||||
c10::ScalarType scalarType);
|
getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
|
||||||
|
|
||||||
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
||||||
/// on failure and emits a diagnostic.
|
/// on failure and emits a diagnostic.
|
||||||
MlirType getMlirTypeFromTorchType(MlirLocation loc,
|
MlirType getMlirTypeFromTorchType(
|
||||||
const c10::TypePtr &torchType,
|
MlirLocation loc, const c10::TypePtr& torchType,
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions& importOptions = {});
|
||||||
|
|
||||||
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
||||||
///
|
///
|
||||||
/// This can differ from the type inferred from the block of a
|
/// This can differ from the type inferred from the block of a
|
||||||
/// torch::jit::Function due to derefinement and refinement of tensor types.
|
/// torch::jit::Function due to derefinement and refinement of tensor types.
|
||||||
MlirType getFunctionTypeFromSchema(MlirContext context,
|
MlirType getFunctionTypeFromSchema(
|
||||||
const c10::FunctionSchema &schema,
|
MlirContext context, const c10::FunctionSchema& schema,
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions& importOptions = {});
|
||||||
|
|
||||||
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||||
MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
|
MlirAttribute
|
||||||
MlirLocation loc);
|
convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc);
|
||||||
|
|
||||||
MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node,
|
MlirAttribute
|
||||||
c10::Symbol symbol);
|
importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol);
|
||||||
|
|
||||||
MlirLocation getMlirLocationFromNode(MlirContext context,
|
MlirLocation
|
||||||
torch::jit::Node *node);
|
getMlirLocationFromNode(MlirContext context, torch::jit::Node* node);
|
||||||
|
|
||||||
std::vector<MlirType>
|
std::vector<MlirType> getMlirTypesFromValues(
|
||||||
getMlirTypesFromValues(MlirLocation loc,
|
MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
|
||||||
c10::ArrayRef<torch::jit::Value *> values,
|
const ImportOptions& importOptions = {});
|
||||||
const ImportOptions &importOptions = {});
|
|
||||||
|
|
||||||
std::vector<MlirValue> adjustStaticInformationForValues(
|
std::vector<MlirValue> adjustStaticInformationForValues(
|
||||||
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
|
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
|
||||||
|
@ -79,11 +78,10 @@ std::vector<MlirValue> adjustStaticInformationForValues(
|
||||||
///
|
///
|
||||||
/// The primary difficulty here is doing the appropriate name munging and
|
/// The primary difficulty here is doing the appropriate name munging and
|
||||||
/// checking if the have a registered op.
|
/// checking if the have a registered op.
|
||||||
MlirOperation createOperationFromSchema(MlirBlock appendToBlock,
|
MlirOperation createOperationFromSchema(
|
||||||
MlirLocation loc,
|
MlirBlock appendToBlock, MlirLocation loc,
|
||||||
const c10::FunctionSchema &schema,
|
const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
|
||||||
c10::ArrayRef<MlirType> resultTypes,
|
c10::ArrayRef<MlirValue> operands);
|
||||||
c10::ArrayRef<MlirValue> operands);
|
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -14,12 +14,12 @@
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||||
#include <torch/csrc/lazy/core/shape.h>
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
#include <base_lazy_backend/backend_impl.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
#include <base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
|
#include <base_lazy_backend/utils/debug.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
|
#include <base_lazy_backend/utils/exception.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
#include <base_lazy_backend/utils/string_utils.h>
|
||||||
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,10 @@
|
||||||
#include "torch/csrc/lazy/core/config.h"
|
#include "torch/csrc/lazy/core/config.h"
|
||||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||||
|
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
#include <base_lazy_backend/utils/string_utils.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
|
#include <base_lazy_backend/utils/sys_utils.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>
|
#include <base_lazy_backend/utils/tensor_utils.h>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
|
@ -2,8 +2,6 @@
|
||||||
# Subdirectories
|
# Subdirectories
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
add_subdirectory(csrc)
|
|
||||||
|
|
||||||
## Declare the sources of the Python module.
|
## Declare the sources of the Python module.
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
||||||
|
|
Loading…
Reference in New Issue