mirror of https://github.com/llvm/torch-mlir
Move files out of TorchPlugin/csrc/builder/ directory.
It was an unnecessary layer of indirection -- there was nothing outside of it, and it was just harder to follow the structure.pull/309/head
parent
8d27c41f21
commit
b738db34cd
|
@ -13,14 +13,13 @@ include_directories(BEFORE
|
|||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
add_library(TorchMLIRTorchPlugin SHARED
|
||||
builder/class_annotator.cpp
|
||||
builder/function_importer.cpp
|
||||
builder/module_builder.cpp
|
||||
builder/node_importer.cpp
|
||||
builder/ivalue_importer.cpp
|
||||
builder/python_bindings.cpp
|
||||
builder/torch_to_mlir_utils.cpp
|
||||
init_python_bindings.cpp
|
||||
class_annotator.cpp
|
||||
function_importer.cpp
|
||||
module_builder.cpp
|
||||
node_importer.cpp
|
||||
ivalue_importer.cpp
|
||||
python_bindings.cpp
|
||||
torch_to_mlir_utils.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(TorchMLIRTorchPlugin
|
||||
|
|
|
@ -34,13 +34,12 @@ static std::string indentString(const std::string &linePrefix,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ClassAnnotation::ClassAnnotation(c10::ClassTypePtr classType)
|
||||
: classType(classType) {
|
||||
: classType(classType) {
|
||||
attributeAnnotations.resize(classType->getAttributes().size());
|
||||
methodAnnotations.resize(classType->methods().size());
|
||||
}
|
||||
|
||||
std::vector<AttributeAnnotation> &
|
||||
ClassAnnotation::getAttributeAnnotations() {
|
||||
std::vector<AttributeAnnotation> &ClassAnnotation::getAttributeAnnotations() {
|
||||
// Halfhearted attempt to ensure consistency if the class type has
|
||||
// been mutated.
|
||||
//
|
||||
|
@ -53,8 +52,7 @@ ClassAnnotation::getAttributeAnnotations() {
|
|||
return attributeAnnotations;
|
||||
}
|
||||
|
||||
std::vector<MethodAnnotation> &
|
||||
ClassAnnotation::getMethodAnnotations() {
|
||||
std::vector<MethodAnnotation> &ClassAnnotation::getMethodAnnotations() {
|
||||
// Halfhearted attempt to ensure consistency if the class type has
|
||||
// been mutated.
|
||||
//
|
||||
|
@ -80,7 +78,8 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator,
|
|||
methodAnnotation.isExported = false;
|
||||
}
|
||||
for (auto &classAttribute : classType->getAttributes()) {
|
||||
if (auto childClassType = classAttribute.getType()->cast<c10::ClassType>()) {
|
||||
if (auto childClassType =
|
||||
classAttribute.getType()->cast<c10::ClassType>()) {
|
||||
exportNoneRecurse(classAnnotator, childClassType.get());
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +106,7 @@ void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
|||
ss << "class '" << classType->name()->qualifiedName()
|
||||
<< "' does not have a method or attribute called '"
|
||||
<< exportedPath.back() << "'";
|
||||
throw std::invalid_argument(ss.str());
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
ClassAnnotation &classAnnotation = getOrCreateClassAnnotation(classType);
|
||||
std::vector<AttributeAnnotation> &attributeAnnotations =
|
||||
|
@ -198,10 +197,9 @@ void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
|
|||
throw std::invalid_argument("Empty annotated path. Can only annotate "
|
||||
"shapes/dtypes of a method of a class.");
|
||||
}
|
||||
c10::ClassType *classType =
|
||||
getClassAtPath(&rootClassType, c10::ArrayRef<std::string>(path)
|
||||
.slice(0, path.size() - 1)
|
||||
.vec());
|
||||
c10::ClassType *classType = getClassAtPath(
|
||||
&rootClassType,
|
||||
c10::ArrayRef<std::string>(path).slice(0, path.size() - 1).vec());
|
||||
|
||||
// Throw error if no method on the class of the specified name.
|
||||
torch::jit::Function *function = &classType->getMethod(path.back());
|
||||
|
@ -265,26 +263,26 @@ std::string AttributeAnnotation::toString(const std::string &name) {
|
|||
}
|
||||
|
||||
std::string ArgAnnotation::toString(int argIndex) {
|
||||
std::stringstream ss;
|
||||
ss << "ArgAnnotation(" << argIndex << ") {\n";
|
||||
ss << " dtype = " << (dtype ? c10::toString(*dtype) : "<none>") << "\n";
|
||||
ss << " shape = ";
|
||||
if (shape) {
|
||||
ss << "[";
|
||||
for (int i = 0, e = shape.value().size(); i != e; i++) {
|
||||
if (i) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << shape.value()[i];
|
||||
std::stringstream ss;
|
||||
ss << "ArgAnnotation(" << argIndex << ") {\n";
|
||||
ss << " dtype = " << (dtype ? c10::toString(*dtype) : "<none>") << "\n";
|
||||
ss << " shape = ";
|
||||
if (shape) {
|
||||
ss << "[";
|
||||
for (int i = 0, e = shape.value().size(); i != e; i++) {
|
||||
if (i) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
} else {
|
||||
ss << "<none>\n";
|
||||
ss << shape.value()[i];
|
||||
}
|
||||
ss << " hasValueSemantics = " << (hasValueSemantics ? "true" : "false")
|
||||
<< "\n";
|
||||
ss << "}\n";
|
||||
return ss.str();
|
||||
ss << "]\n";
|
||||
} else {
|
||||
ss << "<none>\n";
|
||||
}
|
||||
ss << " hasValueSemantics = " << (hasValueSemantics ? "true" : "false")
|
||||
<< "\n";
|
||||
ss << "}\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string MethodAnnotation::toString(const std::string &name) {
|
||||
|
@ -326,7 +324,7 @@ std::string ClassAnnotator::toString() {
|
|||
std::stringstream ss;
|
||||
ss << "ClassAnnotator {\n";
|
||||
for (auto &p : classAnnotations) {
|
||||
ss << indentString(" ", p.second->toString());
|
||||
ss << indentString(" ", p.second->toString());
|
||||
}
|
||||
ss << "}\n";
|
||||
return ss.str();
|
|
@ -21,7 +21,7 @@
|
|||
#ifndef TORCHMLIRPLUGIN_CSRC_CLASS_ANNOTATOR_H
|
||||
#define TORCHMLIRPLUGIN_CSRC_CLASS_ANNOTATOR_H
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "pybind.h"
|
||||
|
||||
namespace torch_mlir {
|
||||
|
|
@ -10,8 +10,8 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "node_importer.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
|
@ -40,9 +40,7 @@ namespace torch_mlir {
|
|||
MlirOperation importJitFunctionAsFuncOp(
|
||||
MlirContext context, torch::jit::Function *function,
|
||||
std::function<MlirAttribute(int)> getArgAttribute =
|
||||
[](int) -> MlirAttribute {
|
||||
return {nullptr};
|
||||
});
|
||||
[](int) -> MlirAttribute { return {nullptr}; });
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
|
@ -19,8 +19,8 @@
|
|||
#include "mlir-c/Diagnostics.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
#include "caffe2/core/scope_guard.h"
|
||||
#include "ATen/native/quantized/cpu/packed_params.h"
|
||||
#include "caffe2/core/scope_guard.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -149,8 +149,7 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
MlirValue
|
||||
IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||
// TODO: Can we do better?
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
|
@ -10,8 +10,8 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "class_annotator.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_H
|
||||
#define TORCHMLIRPLUGIN_CSRC_BUILDER_H
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "class_annotator.h"
|
||||
|
|
@ -211,12 +211,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
|
||||
resultTypes, mlirRegionCreate(), mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
auto createTerminator =
|
||||
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.If.yield", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.If.yield", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator));
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
|
@ -5,13 +5,12 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../pybind.h"
|
||||
// This is the top-level entry point for the MLIR <-> PyTorch bridge.
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
#include "../init_python_bindings.h"
|
||||
#include "module_builder.h"
|
||||
#include "class_annotator.h"
|
||||
#include "module_builder.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
namespace py = pybind11;
|
||||
|
@ -122,7 +121,7 @@ py::list GetRegisteredOps() {
|
|||
|
||||
} // namespace
|
||||
|
||||
void torch_mlir::InitBuilderBindings(py::module &m) {
|
||||
PYBIND11_MODULE(_torch_mlir, m) {
|
||||
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
|
||||
ModuleBuilder::bind(m);
|
||||
initClassAnnotatorBindings(m);
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "pybind.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
Loading…
Reference in New Issue