mirror of https://github.com/llvm/torch-mlir
First step of moving the common jit_ir_importer.
parent 606dc45896
commit f1d9136210
@@ -0,0 +1 @@
+add_subdirectory(csrc/jit_ir_importer)
@@ -0,0 +1,26 @@
+# Static library with core functionality.
+# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
+# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
+add_library(TorchMLIRJITIRImporter STATIC
+  class_annotator.cpp
+  function_importer.cpp
+  node_importer.cpp
+  ivalue_importer.cpp
+  torch_to_mlir_utils.cpp
+)
+target_link_libraries(TorchMLIRJITIRImporter
+  TorchMLIRAggregateCAPI
+  ${TORCH_LIBRARIES}
+)
+# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
+target_include_directories(TorchMLIRJITIRImporter PUBLIC
+  ${CMAKE_CURRENT_SOURCE_DIR}/..
+)
+set_target_properties(TorchMLIRJITIRImporter PROPERTIES
+  LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
+  OUTPUT_NAME lib_jit_ir_importer
+  PREFIX ""
+  SUFFIX ".a"
+  CXX_VISIBILITY_PRESET "default"
+  COMPILE_FLAGS "${TORCH_CXXFLAGS}"
+)
@@ -18,8 +18,8 @@ using namespace torch_mlir;
 //===----------------------------------------------------------------------===//

 // Prefix every line of `s` with `linePrefix`.
-static std::string
-indentString(const std::string& linePrefix, const std::string& s) {
+static std::string indentString(const std::string &linePrefix,
+                                const std::string &s) {
   std::stringstream is(s);
   std::stringstream os;
   std::string line;

@@ -46,8 +46,7 @@ std::vector<AttributeAnnotation>& ClassAnnotation::getAttributeAnnotations() {
   // We can't easily guard against attributes being removed and
   // then other attributes being added, or types changed, etc. without
   // effectively mirroring the entire ClassType.
-  assert(
-      attributeAnnotations.size() == classType->getAttributes().size() &&
-      "annotations out of sync. class has been mutated");
+  assert(attributeAnnotations.size() == classType->getAttributes().size() &&
+         "annotations out of sync. class has been mutated");

   return attributeAnnotations;

@@ -58,8 +57,7 @@ std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() {
   // been mutated.
   //
   // We can't easily guard against methods being removed, added, or changed.
-  assert(
-      methodAnnotations.size() == classType->methods().size() &&
-      "annotations out of sync. class has been mutated");
+  assert(methodAnnotations.size() == classType->methods().size() &&
+         "annotations out of sync. class has been mutated");

   return methodAnnotations;

@@ -69,8 +67,8 @@ std::vector<MethodAnnotation>& ClassAnnotation::getMethodAnnotations() {
 // ClassAnnotator
 //===----------------------------------------------------------------------===//

-static void
-exportNoneRecurse(ClassAnnotator& classAnnotator, c10::ClassType* classType) {
+static void exportNoneRecurse(ClassAnnotator &classAnnotator,
+                              c10::ClassType *classType) {
   ClassAnnotation &classAnnotation =
       classAnnotator.getOrCreateClassAnnotation(classType);
   for (auto &attributeAnnotation : classAnnotation.getAttributeAnnotations()) {

@@ -91,14 +89,14 @@ void ClassAnnotator::exportNone(c10::ClassType& rootClassType) {
   exportNoneRecurse(*this, &rootClassType);
 }

-void ClassAnnotator::exportPath(
-    c10::ClassType& rootClassType, std::vector<std::string> exportedPath) {
+void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
+                                std::vector<std::string> exportedPath) {
   if (exportedPath.size() == 0) {
     throw std::invalid_argument(
         "Empty exported path. Can only export a property of a class.");
   }
-  c10::ClassType* classType = getClassAtPath(
-      &rootClassType, c10::ArrayRef<std::string>(exportedPath)
-                          .slice(0, exportedPath.size() - 1)
-                          .vec());
+  c10::ClassType *classType =
+      getClassAtPath(&rootClassType, c10::ArrayRef<std::string>(exportedPath)
+                                         .slice(0, exportedPath.size() - 1)
+                                         .vec());

@@ -151,23 +149,23 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType* classType) {
   return *it->second;
 }

-static void fillArgAnnotations(
-    MethodAnnotation& methodAnnotation,
-    std::vector<ArgAnnotation> argAnnotations, torch::jit::Function* function) {
+static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
+                               std::vector<ArgAnnotation> argAnnotations,
+                               torch::jit::Function *function) {
   if (argAnnotations.size() != function->num_inputs()) {
     throw std::invalid_argument("Arg annotations should have one entry per "
                                 "function parameter (including self).");
   }
   if (!methodAnnotation.argAnnotations.has_value()) {
-    methodAnnotation.argAnnotations.emplace(
-        function->num_inputs(), ArgAnnotation{});
+    methodAnnotation.argAnnotations.emplace(function->num_inputs(),
+                                            ArgAnnotation{});
   }

   methodAnnotation.argAnnotations = argAnnotations;
 }

-void ClassAnnotator::annotateArgs(
-    c10::ClassType& rootClassType, std::vector<std::string> path,
-    std::vector<ArgAnnotation> argAnnotations) {
+void ClassAnnotator::annotateArgs(c10::ClassType &rootClassType,
+                                  std::vector<std::string> path,
+                                  std::vector<ArgAnnotation> argAnnotations) {
   if (path.size() == 0) {
     throw std::invalid_argument("Empty annotated path. Can only annotate "

@@ -193,8 +191,8 @@ void ClassAnnotator::annotateArgs(
   return;
 }

-c10::ClassType* ClassAnnotator::getClassAtPath(
-    c10::ClassType* rootClassType, std::vector<std::string> path) {
+c10::ClassType *ClassAnnotator::getClassAtPath(c10::ClassType *rootClassType,
+                                               std::vector<std::string> path) {
   c10::ClassType *classType = rootClassType;
   // Reverse so that pop_back gives us the initial atoms first.
   std::reverse(path.begin(), path.end());
@@ -141,8 +141,8 @@ public:
   // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
   // have a submodule `a` and that submodule should have a method or attribute
   // `b`.
-  void exportPath(
-      c10::ClassType& rootClassType, std::vector<std::string> exportedPath);
+  void exportPath(c10::ClassType &rootClassType,
+                  std::vector<std::string> exportedPath);
   // Mark everything as not-exported.
   //
   // This is kind of useless by itself, but together with `exportPath` allows

@@ -159,8 +159,8 @@ public:
   // a "has value semantics" boolean.
   // These will be put into an `ArgAnnotation` struct -- see there for
   // precise definitions of the promised semantics of each entry.
-  void annotateArgs(
-      c10::ClassType& rootClassType, std::vector<std::string> path,
-      std::vector<ArgAnnotation> argAnnotations);
+  void annotateArgs(c10::ClassType &rootClassType,
+                    std::vector<std::string> path,
+                    std::vector<ArgAnnotation> argAnnotations);

   // The annotations collected so far.

@@ -183,8 +183,8 @@ private:
   // Traverse `path` starting from `rootClassType` to find the ClassType
   // of a presumed nested submodule. Throw an error if there is no such
   // submodule.
-  c10::ClassType*
-  getClassAtPath(c10::ClassType* rootClassType, std::vector<std::string> path);
+  c10::ClassType *getClassAtPath(c10::ClassType *rootClassType,
+                                 std::vector<std::string> path);
   ClassAnnotationMap classAnnotations;
   // Reverse mapping used to service getMethodAnnotationForFunction.
   std::unordered_map<torch::jit::Function *, MethodAnnotation *>
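A minimal sketch (not part of this patch) of how the annotation API declared above might be driven from a frontend. Assumptions are marked in comments: a scripted torch::jit::Module with a `forward(self, x)` method, the include path layout described in the new CMakeLists.txt (`#include "jit_ir_importer/..."`), and default-constructed ArgAnnotation entries (which, as in `fillArgAnnotations` above, promise nothing extra about each argument).

// Sketch only: exercises exportNone/exportPath/annotateArgs as declared above.
#include <string>
#include <vector>

#include <torch/script.h>

#include "jit_ir_importer/class_annotator.h" // assumed header location

using namespace torch_mlir;

void annotateForward(torch::jit::Module &module, ClassAnnotator &annotator) {
  c10::ClassType &rootClassType = *module.type();
  // Start from "nothing exported", then export just the method we care about.
  annotator.exportNone(rootClassType);
  annotator.exportPath(rootClassType, {"forward"});
  // One ArgAnnotation per parameter, *including* `self` (enforced by
  // fillArgAnnotations); defaults leave shape/dtype/value-semantics unknown.
  std::vector<ArgAnnotation> argAnnotations(2, ArgAnnotation{});
  annotator.annotateArgs(rootClassType, {"forward"}, argAnnotations);
}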
@@ -63,8 +63,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
   }
   auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                               MlirBlock appendToBlock) {
-    createMlirOperationAtEnd(
-        appendToBlock, "func.return", loc,
+    createMlirOperationAtEnd(appendToBlock, "func.return", loc,
        adjustStaticInformationForValues(
            appendToBlock, loc, yieldedValues, resultTypes,
            /*userAllowsRefinement=*/false));
@@ -99,9 +99,8 @@ namespace {
 /// (PyTorch allows this!).
 class IValueImporter {
 public:
-  IValueImporter(
-      MlirBlock importBlock, MlirContext context, ClassAnnotator& annotator,
-      const ImportOptions& importOptions)
+  IValueImporter(MlirBlock importBlock, MlirContext context,
+                 ClassAnnotator &annotator, const ImportOptions &importOptions)
       : importBlock(importBlock), context(context), annotator(annotator),
         importOptions(importOptions) {}

@@ -111,8 +110,7 @@ private:
   MlirValue rawImportIValue(c10::IValue ivalue);
   MlirValue importTensor(c10::IValue ivalue);
   MlirValue importModule(torch::jit::Module jitModule);
-  void importMethod(
-      torch::jit::Function* function, MlirBlock classTypeBody,
+  void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
                     const MethodAnnotation &methodAnnotation);
   void importClassType(c10::ClassType *classType);
   void importCompilationUnit(torch::jit::CompilationUnit *cu);

@@ -192,8 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
       torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
       mlirRegionCreate());
   MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
-  mlirRegionAppendOwnedBlock(
-      nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
+  mlirRegionAppendOwnedBlock(nnModuleRegion,
+                             mlirBlockCreate(0, nullptr, nullptr));
   MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
   InserterGuard inserterGuard(importBlock, nnModule);

@@ -204,8 +202,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
   const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
   const std::vector<c10::ClassAttribute> &classAttributes =
       currentModule.type()->getAttributes();
-  assert(
-      slots.size() == classAttributes.size() &&
+  assert(slots.size() == classAttributes.size() &&
          "mismatch between object and type!");
   for (int i = 0, e = slots.size(); i < e; i++) {
     const c10::ClassAttribute &classAttribute = classAttributes[i];

@@ -261,8 +258,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     MlirType type = torchMlirTorchBoolTypeGet(context);
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.constant.bool", loc, type,
-        toMlirNamedAttribute(
-            "value", mlirBoolAttrGet(context, ivalue.toBool())));
+        toMlirNamedAttribute("value",
+                             mlirBoolAttrGet(context, ivalue.toBool())));
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isDouble()) {
@@ -270,17 +267,17 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.constant.float", loc, type,
         toMlirNamedAttribute(
-            "value", mlirFloatAttrDoubleGet(
-                         context, mlirF64TypeGet(context), ivalue.toDouble())));
+            "value", mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
+                                            ivalue.toDouble())));
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isInt()) {
     MlirType type = torchMlirTorchIntTypeGet(context);
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.constant.int", loc, type,
-        toMlirNamedAttribute(
-            "value", mlirIntegerAttrGet(
-                         mlirIntegerTypeGet(context, 64), ivalue.toInt())));
+        toMlirNamedAttribute("value",
+                             mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64),
+                                                ivalue.toInt())));
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isList()) {

@@ -339,13 +336,13 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
         torchMlirTorchStringTypeGet(context),
         toMlirNamedAttribute(
             "value",
-            mlirStringAttrGet(
-                context, toMlirStringRef(ivalue.toString()->string()))));
+            mlirStringAttrGet(context,
+                              toMlirStringRef(ivalue.toString()->string()))));
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isNone()) {
-    MlirOperation operation = createMlirOperationAtEnd(
-        importBlock, "torch.constant.none", loc,
-        torchMlirTorchNoneTypeGet(context));
+    MlirOperation operation =
+        createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
+                                 torchMlirTorchNoneTypeGet(context));
     return mlirOperationGetResult(operation, 0);
   }

@@ -440,8 +437,8 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   return tensorValue;
 }

-void IValueImporter::importMethod(
-    torch::jit::Function* function, MlirBlock classTypeBody,
+void IValueImporter::importMethod(torch::jit::Function *function,
+                                  MlirBlock classTypeBody,
                                   const MethodAnnotation &methodAnnotation) {
   // The function's name becomes the MLIR symbol table name of the imported func
   // when we import the compilation unit.

@@ -568,8 +565,8 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
       int64_t dummy;
      int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
      if (hasValueSemantics) {
-        typeBound = torchMlirTorchValueTensorTypeGet(
-            context, shape.size(), shapeData, dtype);
+        typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
+                                                     shapeData, dtype);
      } else {
        typeBound = torchMlirTorchNonValueTensorTypeGet(
            context, shape.size(), shapeData, dtype);

@@ -597,9 +594,10 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit* cu) {
   }
 }

-MlirValue torch_mlir::importIValue(
-    c10::IValue ivalue, MlirBlock block, MlirContext context,
-    ClassAnnotator& annotator, const ImportOptions& importOptions) {
+MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
+                                   MlirContext context,
+                                   ClassAnnotator &annotator,
+                                   const ImportOptions &importOptions) {
   // When debugging module importing, it can be useful to dump as so:
   //   if (ivalue.isModule())
   //     ivalue.toModule().dump(true, false, false);
@@ -25,9 +25,9 @@ namespace torch_mlir {

 /// Main entry-point for importing torch IValue's.
 /// Recursively imports `ivalue`, inserting operations at the end of `block`.
-MlirValue importIValue(
-    c10::IValue ivalue, MlirBlock block, MlirContext context,
-    ClassAnnotator& annotator, const ImportOptions& importOptions);
+MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
+                       ClassAnnotator &annotator,
+                       const ImportOptions &importOptions);

 } // namespace torch_mlir
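A hypothetical end-to-end driver for the entry point declared above, not part of this patch. It assumes the torch-mlir CAPI registration hook is `torchMlirRegisterAllDialects` from "torch-mlir-c/Registration.h" and that the importer headers are reachable as "jit_ir_importer/..." (per the CMake include comment); everything else is the standard MLIR C API.

// Sketch only, under the assumptions stated above.
#include "mlir-c/IR.h"
#include "torch-mlir-c/Registration.h" // assumed header/function name

#include "jit_ir_importer/class_annotator.h"
#include "jit_ir_importer/ivalue_importer.h"

MlirModule importIntConstant(MlirContext context) {
  torchMlirRegisterAllDialects(context); // assumed CAPI registration hook
  MlirLocation loc = mlirLocationUnknownGet(context);
  MlirModule module = mlirModuleCreateEmpty(loc);
  torch_mlir::ClassAnnotator annotator; // no annotations needed for a plain int
  torch_mlir::ImportOptions importOptions;
  // Imports c10::IValue(42) as a torch.constant.int at the end of the module
  // body, following the isInt() path shown in the importer above.
  torch_mlir::importIValue(c10::IValue(42), mlirModuleGetBody(module), context,
                           annotator, importOptions);
  return module;
}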
@@ -30,50 +30,50 @@ inline MlirStringRef toMlirStringRef(const char* s) {
   return mlirStringRefCreate(s, std::strlen(s));
 }

-inline MlirNamedAttribute
-toMlirNamedAttribute(const char* s, MlirAttribute attr) {
+inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
+                                               MlirAttribute attr) {
   MlirContext context = mlirAttributeGetContext(attr);
   MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
   return mlirNamedAttributeGet(ident, attr);
 }

-inline void addToMlirOperationState(
-    MlirOperationState& state, MlirNamedAttribute namedAttr) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    MlirNamedAttribute namedAttr) {
   mlirOperationStateAddAttributes(&state, 1, &namedAttr);
 }

-inline void
-addToMlirOperationState(MlirOperationState& state, MlirRegion region) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    MlirRegion region) {
   mlirOperationStateAddOwnedRegions(&state, 1, &region);
 }

-inline void
-addToMlirOperationState(MlirOperationState& state, MlirValue value) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    MlirValue value) {
   mlirOperationStateAddOperands(&state, 1, &value);
 }

-inline void addToMlirOperationState(
-    MlirOperationState& state, const std::vector<MlirValue>& values) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    const std::vector<MlirValue> &values) {
   mlirOperationStateAddOperands(&state, values.size(), values.data());
 }

-inline void addToMlirOperationState(
-    MlirOperationState& state, c10::ArrayRef<MlirValue> values) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    c10::ArrayRef<MlirValue> values) {
   mlirOperationStateAddOperands(&state, values.size(), values.data());
 }

-inline void
-addToMlirOperationState(MlirOperationState& state, MlirType resultType) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    MlirType resultType) {
   mlirOperationStateAddResults(&state, 1, &resultType);
 }

-inline void addToMlirOperationState(
-    MlirOperationState& state, const std::vector<MlirType>& resultTypes) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    const std::vector<MlirType> &resultTypes) {
   mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
 }

-inline void addToMlirOperationState(
-    MlirOperationState& state, c10::ArrayRef<MlirType> resultTypes) {
+inline void addToMlirOperationState(MlirOperationState &state,
+                                    c10::ArrayRef<MlirType> resultTypes) {
   mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
 }

@@ -87,27 +87,27 @@ void addToMlirOperationState(MlirOperationState& state, c10::optional<T> o) {
 inline void addToMlirOperationState(MlirOperationState &state) {}

 template <typename T, typename U, typename... Ts>
-void addToMlirOperationState(
-    MlirOperationState& state, T&& t, U&& u, Ts&&... ts) {
+void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u,
+                             Ts &&...ts) {
   addToMlirOperationState(state, std::forward<T>(t));
   addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
 }

 template <typename... Ts>
-MlirOperation
-createMlirOperation(std::string name, MlirLocation loc, Ts&&... ts) {
+MlirOperation createMlirOperation(std::string name, MlirLocation loc,
+                                  Ts &&...ts) {
   MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
   addToMlirOperationState(state, std::forward<Ts>(ts)...);
   return mlirOperationCreate(&state);
 }

 template <typename... Ts>
-MlirOperation createMlirOperationAtEnd(
-    MlirBlock block, std::string name, MlirLocation loc, Ts&&... ts) {
+MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
+                                       MlirLocation loc, Ts &&...ts) {
   MlirOperation operation =
       createMlirOperation(name, loc, std::forward<Ts>(ts)...);
-  mlirBlockInsertOwnedOperationBefore(
-      block, mlirBlockGetTerminator(block), operation);
+  mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
+                                      operation);
   return operation;
 }

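A small illustration (not from the patch) of what the variadic helpers above buy: every trailing argument of createMlirOperationAtEnd is routed to the matching addToMlirOperationState overload (result type, operand, region, or named attribute), so op construction stays a single call. The header names below are assumptions; torchMlirTorchBoolTypeGet comes from the torch-mlir CAPI, everything else is the standard MLIR C API.

// Sketch only, under the assumptions stated above.
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "torch-mlir-c/TorchTypes.h" // assumed header name

#include "jit_ir_importer/mlir_utils.h" // assumed header name for the helpers

using namespace torch_mlir;

MlirValue emitConstantBool(MlirContext context, MlirBlock block,
                           MlirLocation loc, bool value) {
  // Builds the op state, inserts the op before the block's terminator (or at
  // the end if there is none), and returns the created operation.
  MlirOperation op = createMlirOperationAtEnd(
      block, "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
      toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
  return mlirOperationGetResult(op, 0);
}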
@@ -33,8 +33,7 @@ class NodeImporter {
 public:
   NodeImporter(MlirContext context) : context(context) {}

-  void importNode(
-      Node* node, MlirBlock appendToBlock,
+  void importNode(Node *node, MlirBlock appendToBlock,
                   const ImportOptions &importOptions = {});
   MlirBlock importBlock(
       Block *jitBlock, CreateTerminatorFn createTerminator,

@@ -42,8 +41,8 @@ public:
       const ImportOptions &importOptions = {});

 private:
-  MlirBlock createBlockFor(
-      Block* jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
+  MlirBlock createBlockFor(Block *jitBlock,
+                           c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
                            const ImportOptions &importOptions = {});
   void mapValue(Value *jitValue, MlirValue value);
   void mapResults(Node *node, MlirOperation operation);

@@ -66,8 +65,7 @@ static std::vector<MlirValue>
 rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
   if (inputs.empty())
     return inputs;
-  assert(
-      inputs.size() % 2 == 0 &&
+  assert(inputs.size() % 2 == 0 &&
          "DictConstruct must have even number of operands");

   std::vector<MlirValue> rearranged;

@@ -80,8 +78,8 @@ rearrangeDictConstructInputs(std::vector<MlirValue>& inputs) {
   return rearranged;
 }

-void NodeImporter::importNode(
-    Node* node, MlirBlock appendToBlock, const ImportOptions& importOptions) {
+void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
+                              const ImportOptions &importOptions) {
   MlirLocation loc = getMlirLocationFromNode(context, node);
   auto kind = node->kind();

@@ -140,8 +138,8 @@ void NodeImporter::importNode(
       }
       return type;
     });
-    createAndMapTrivialNode(
-        node, "torch.prim." + std::string(kind.toUnqualString()),
+    createAndMapTrivialNode(node,
+                            "torch.prim." + std::string(kind.toUnqualString()),
        [&](std::vector<MlirValue> &inputs) {
          assert(containedTypes.size() == inputs.size());
          return adjustStaticInformationForValues(

@@ -151,8 +149,8 @@ void NodeImporter::importNode(
     return;
   }
   case c10::prim::DictConstruct: {
-    createAndMapTrivialNode(
-        node, "torch.prim." + std::string(kind.toUnqualString()),
+    createAndMapTrivialNode(node,
+                            "torch.prim." + std::string(kind.toUnqualString()),
                             rearrangeDictConstructInputs);
     return;
   }
@@ -171,34 +169,32 @@ void NodeImporter::importNode(
     auto output = node->output();
     MlirOperation op;
     if (output->type()->cast<c10::NoneType>()) {
-      op = createMlirOperation(
-          "torch.constant.none", loc, torchMlirTorchNoneTypeGet(context));
+      op = createMlirOperation("torch.constant.none", loc,
+                               torchMlirTorchNoneTypeGet(context));
     } else if (output->type()->cast<c10::BoolType>()) {
       op = createMlirOperation(
           "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
           toMlirNamedAttribute(
-              "value",
-              mlirBoolAttrGet(
-                  context, static_cast<bool>(node->i(c10::attr::value)))));
+              "value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
+                           c10::attr::value)))));
     } else if (output->type()->cast<c10::IntType>()) {
       op = createMlirOperation(
           "torch.constant.int", loc,
           getMlirTypeFromTorchType(loc, output->type(), importOptions),
-          toMlirNamedAttribute(
-              "value", importAttribute(loc, node, c10::attr::value)));
+          toMlirNamedAttribute("value",
+                               importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::FloatType>()) {
       op = createMlirOperation(
           "torch.constant.float", loc,
           getMlirTypeFromTorchType(loc, output->type(), importOptions),
-          toMlirNamedAttribute(
-              "value", importAttribute(loc, node, c10::attr::value)));
+          toMlirNamedAttribute("value",
+                               importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::StringType>()) {
       op = createMlirOperation(
           "torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
           toMlirNamedAttribute(
-              "value",
-              mlirStringAttrGet(
-                  context, toMlirStringRef(node->s(c10::attr::value)))));
+              "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
+                           c10::attr::value)))));
     } else if (output->type()->cast<c10::TensorType>()) {
       MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
       if (importOptions.assumeTensorsHaveValueSemantics) {

@@ -217,26 +213,24 @@ void NodeImporter::importNode(
           "torch.constant.device", loc,
           getMlirTypeFromTorchType(loc, output->type(), importOptions),
           toMlirNamedAttribute(
-              "value",
-              mlirStringAttrGet(
-                  context, toMlirStringRef(node->s(c10::attr::value)))));
+              "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
+                           c10::attr::value)))));
     } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
       torch::jit::Function *function = functionType->function();
       const std::string &symName = function->qualname().qualifiedName();
       op = createMlirOperation(
           "func.constant", loc,
-          getFunctionTypeFromSchema(
-              context, function->getSchema(), importOptions),
+          getFunctionTypeFromSchema(context, function->getSchema(),
+                                    importOptions),
           toMlirNamedAttribute(
               "value",
               mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
-    } else if (
-        output->type()->cast<c10::ListType>() ||
-        output->type()->cast<c10::TupleType>()) {
+    } else if (output->type()->cast<c10::ListType>() ||
+               output->type()->cast<c10::TupleType>()) {
       ClassAnnotator dummyAnnotator;
-      MlirValue listOrTupleValue = importIValue(
-          node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator,
-          importOptions);
+      MlirValue listOrTupleValue =
+          importIValue(node->ival(c10::attr::value), appendToBlock, context,
+                       dummyAnnotator, importOptions);
       mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
       return; // Early return, since `importIValue` already added op to block.
     } else {
@@ -264,20 +258,19 @@ void NodeImporter::importNode(
     mapResults(node, operation);
     std::vector<MlirType> terminatorOperandTypes = {
         torchMlirTorchBoolTypeGet(context)};
-    terminatorOperandTypes.insert(
-        terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end());
+    terminatorOperandTypes.insert(terminatorOperandTypes.end(),
+                                  resultTypes.begin(), resultTypes.end());
     auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                 MlirBlock appendToBlock) {
       createMlirOperationAtEnd(
           appendToBlock, "torch.prim.Loop.condition", loc,
-          adjustStaticInformationForValues(
-              appendToBlock, loc, yieldedValues, terminatorOperandTypes,
+          adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
+                                           terminatorOperandTypes,
              /*userAllowsRefinement=*/false));
     };
-    mlirRegionAppendOwnedBlock(
-        mlirOperationGetRegion(operation, 0),
-        importBlock(
-            node->blocks()[0], createTerminator, c10::nullopt, importOptions));
+    mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
+                               importBlock(node->blocks()[0], createTerminator,
+                                           c10::nullopt, importOptions));
     return;
   }

@@ -292,18 +285,16 @@ void NodeImporter::importNode(
                                 MlirBlock appendToBlock) {
       createMlirOperationAtEnd(
           appendToBlock, "torch.prim.If.yield", loc,
-          adjustStaticInformationForValues(
-              appendToBlock, loc, yieldedValues, resultTypes,
+          adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
+                                           resultTypes,
              /*userAllowsRefinement=*/false));
     };
-    mlirRegionAppendOwnedBlock(
-        mlirOperationGetRegion(operation, 0),
-        importBlock(
-            node->blocks()[0], createTerminator, c10::nullopt, importOptions));
-    mlirRegionAppendOwnedBlock(
-        mlirOperationGetRegion(operation, 1),
-        importBlock(
-            node->blocks()[1], createTerminator, c10::nullopt, importOptions));
+    mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
+                               importBlock(node->blocks()[0], createTerminator,
+                                           c10::nullopt, importOptions));
+    mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
+                               importBlock(node->blocks()[1], createTerminator,
+                                           c10::nullopt, importOptions));
     return;
   }

@@ -323,8 +314,8 @@ void NodeImporter::importNode(
        adjustStaticInformationForValues(
            appendToBlock, loc, lookupMappedValues(node->inputs()),
            expectedTypes, /*userAllowsRefinement=*/false),
-        toMlirNamedAttribute(
-            "name", importAttribute(loc, node, c10::attr::name)));
+        toMlirNamedAttribute("name",
+                             importAttribute(loc, node, c10::attr::name)));
     mapResults(node, operation);
     return;
   }

@@ -348,9 +339,9 @@ void NodeImporter::importNode(
     // promoted result dtype for a PyTorch computation. Here we turn the call to
     // this function to the torch dialect equivalent op `torch.promote_dtypes`.
     if (functionName == "__torch_mlir_internal_promote_dtypes") {
-      operation = createMlirOperationAtEnd(
-          appendToBlock, "torch.promote_dtypes", loc, resultTypes,
-          adjustedFuncArgs);
+      operation =
+          createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc,
+                                   resultTypes, adjustedFuncArgs);
     } else {
       operation = createMlirOperationAtEnd(
           appendToBlock, "func.call_indirect", loc, resultTypes,

@@ -369,8 +360,8 @@ void NodeImporter::importNode(
   }
 }

-MlirBlock NodeImporter::importBlock(
-    Block* jitBlock, CreateTerminatorFn createTerminator,
+MlirBlock
+NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator,
                           c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
                           const ImportOptions &importOptions) {
   MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
@@ -394,9 +385,9 @@ MlirBlock NodeImporter::createBlockFor(
   else
     assert(blockArgTypes->size() == paramNodeTypes.size());
   std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
-  MlirBlock block = mlirBlockCreate(
-      blockArgTypes.value().size(), blockArgTypes.value().data(),
-      blockArgLocs.data());
+  MlirBlock block =
+      mlirBlockCreate(blockArgTypes.value().size(),
+                      blockArgTypes.value().data(), blockArgLocs.data());
   for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
     Value *jitValue = paramNode->outputs()[i];
     MlirValue value = mlirBlockGetArgument(block, i);

@@ -415,16 +406,15 @@ void NodeImporter::mapValue(Value* jitValue, MlirValue value) {
   valueMap[jitValue] = value;
 }
 void NodeImporter::mapResults(Node *node, MlirOperation operation) {
-  assert(
-      node->outputs().size() == (size_t)mlirOperationGetNumResults(operation));
+  assert(node->outputs().size() ==
+         (size_t)mlirOperationGetNumResults(operation));
   for (int i = 0, e = node->outputs().size(); i < e; i++) {
     mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
   }
 }
 MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
   auto it = valueMap.find(jitValue);
-  assert(
-      it != valueMap.end() &&
+  assert(it != valueMap.end() &&
          "trying to get mapping for jitValue that is not mapped yet!");
   return it->second;
 }

@@ -437,11 +427,12 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value*> values) {
   return ret;
 }

-MlirBlock torch_mlir::importBlock(
-    MlirContext context, Block* jitBlock, CreateTerminatorFn createTerminator,
+MlirBlock
+torch_mlir::importBlock(MlirContext context, Block *jitBlock,
+                        CreateTerminatorFn createTerminator,
                         c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
                         const ImportOptions &importOptions) {
   NodeImporter importer(context);
-  return importer.importBlock(
-      jitBlock, createTerminator, blockArgTypes, importOptions);
+  return importer.importBlock(jitBlock, createTerminator, blockArgTypes,
+                              importOptions);
 }
@@ -36,8 +36,8 @@ using CreateTerminatorFn =
 /// are required to be for correctness. The code will internally attempt to
 /// adjust the types to the block argument types.
 /// TODO: Formalize what type conversions are allowed here.
-MlirBlock importBlock(
-    MlirContext context, torch::jit::Block* jitBlock,
+MlirBlock
+importBlock(MlirContext context, torch::jit::Block *jitBlock,
             CreateTerminatorFn createTerminator,
             c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
             const ImportOptions &importOptions = {});
@@ -26,8 +26,8 @@

 using namespace torch_mlir;

-static MlirType getMlirTypeForTorchScalarTypeRaw(
-    MlirContext context, c10::ScalarType scalarType) {
+static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
+                                                 c10::ScalarType scalarType) {
   using c10::ScalarType;
   switch (scalarType) {
   case ScalarType::Byte:

@@ -69,8 +69,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(
   }
 }

-MlirType torch_mlir::getMlirTypeForTorchScalarType(
-    MlirLocation loc, c10::ScalarType scalarType) {
+MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
+                                                   c10::ScalarType scalarType) {
   auto type =
       getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
   if (mlirTypeIsNull(type)) {

@@ -98,8 +98,8 @@ MlirType torch_mlir::getMlirTypeForTorchScalarType(
 // There is no generic way to import custom classes (or their types), so we
 // have to name match them here (and the relevant code in the ivalue
 // importer) and create special IR constructs for them.
-static MlirType mapCustomClassType(
-    MlirContext context, MlirLocation loc, const c10::ClassTypePtr& classType) {
+static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
+                                   const c10::ClassTypePtr &classType) {
   // If the type is unnamed, it cannot be a custom class.
   if (!classType->name().has_value()) {
     return {nullptr};

@@ -126,8 +126,9 @@ static MlirType mapCustomClassType(
     throw mlir_diagnostic_emitted();
   }
 }

-MlirType torch_mlir::getMlirTypeFromTorchType(
-    MlirLocation loc, const c10::TypePtr& torchType,
+MlirType
+torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
+                                     const c10::TypePtr &torchType,
                                      const ImportOptions &importOptions) {
   MlirContext context = mlirLocationGetContext(loc);
   using c10::TypeKind;

@@ -140,8 +141,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
                                          : torchMlirTorchNonValueTensorTypeGet;

     if (importOptions.ignoreExistingTensorShapesAndDtypes) {
-      return getMlirTensorType(
-          context,
+      return getMlirTensorType(context,
                                /*numSizes=*/-1,
                                /*optionalSizes=*/nullptr,
                                /*optionalDtype=*/{nullptr});

@@ -159,8 +159,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
     auto &sizes = tensorType->symbolic_sizes();
     if (!sizes.rank()) {
       // Unranked.
-      return getMlirTensorType(
-          context,
+      return getMlirTensorType(context,
                                /*numSizes=*/-1,
                                /*optionalSizes=*/nullptr,
                                /*optionalDtype=*/

@@ -181,8 +180,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
     // case. So use a dummy data pointer.
     int64_t dummy;
     int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
-    return getMlirTensorType(
-        context, dims.size(),
+    return getMlirTensorType(context, dims.size(),
                              /*optionalSizes=*/dimsData,
                              /*optionalDtype=*/
                              elementType);
@@ -214,8 +212,8 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
       containedTypes.push_back(
           getMlirTypeFromTorchType(loc, type, importOptions));
     }
-    return torchMlirTorchTupleTypeGet(
-        context, containedTypes.size(), containedTypes.data());
+    return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
+                                      containedTypes.data());
   }
   case TypeKind::UnionType: {
     std::vector<MlirType> containedTypes;

@@ -223,8 +221,8 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
              torchType->cast<c10::UnionType>()->containedTypes()) {
       containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
     }
-    return torchMlirTorchUnionTypeGet(
-        context, containedTypes.size(), containedTypes.data());
+    return torchMlirTorchUnionTypeGet(context, containedTypes.size(),
+                                      containedTypes.data());
   }
   case TypeKind::ListType: {
     return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(

@@ -268,8 +266,9 @@ MlirType torch_mlir::getMlirTypeFromTorchType(
   }
 }

-MlirType torch_mlir::getFunctionTypeFromSchema(
-    MlirContext context, const c10::FunctionSchema& schema,
+MlirType
+torch_mlir::getFunctionTypeFromSchema(MlirContext context,
+                                      const c10::FunctionSchema &schema,
                                       const ImportOptions &importOptions) {
   MlirLocation loc = mlirLocationUnknownGet(context);
   auto mapType = [&](const c10::TypePtr &torchType) {

@@ -284,20 +283,17 @@ MlirType torch_mlir::getFunctionTypeFromSchema(
   };

   std::vector<MlirType> inputTypes =
-      c10::fmap(schema.arguments(), [&](const c10::Argument& arg) {
-        return mapType(arg.type());
-      });
+      c10::fmap(schema.arguments(),
+                [&](const c10::Argument &arg) { return mapType(arg.type()); });
   std::vector<MlirType> outputTypes =
-      c10::fmap(schema.returns(), [&](const c10::Argument& arg) {
-        return mapType(arg.type());
-      });
-  return mlirFunctionTypeGet(
-      context, inputTypes.size(), inputTypes.data(), outputTypes.size(),
-      outputTypes.data());
+      c10::fmap(schema.returns(),
+                [&](const c10::Argument &arg) { return mapType(arg.type()); });
+  return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
+                             outputTypes.size(), outputTypes.data());
 }

-MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
-    at::Tensor tensor, MlirLocation loc) {
+MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
+                                                          MlirLocation loc) {
   using at::ScalarType;

   auto throwUnsupportedTensorError = [&]() {
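A short sketch (not in the patch) of the schema-to-type mapping implemented above: every schema argument and return is mapped through getMlirTypeFromTorchType and packed into an MLIR function type. The parseSchema helper and its header path, as well as the torch_to_mlir_utils.h header name, are assumptions.

// Sketch only, under the assumptions stated above.
#include <torch/csrc/jit/frontend/function_schema_parser.h> // assumed path for parseSchema

#include "jit_ir_importer/torch_to_mlir_utils.h" // assumed header name

MlirType reluFunctionType(MlirContext context) {
  c10::FunctionSchema schema =
      torch::jit::parseSchema("aten::relu(Tensor self) -> Tensor");
  // Default-constructed ImportOptions.
  return torch_mlir::getFunctionTypeFromSchema(context, schema,
                                               torch_mlir::ImportOptions{});
}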
@@ -312,8 +308,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(

   // The flat number of bytes throws an exception for tensors that are not
   // dense and accessible as such.
-  at::checkLayout(
-      at::CheckedFrom("accessing contiguous"), tensor, c10::Layout::Strided);
+  at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
+                  c10::Layout::Strided);

   // Construct the ShapedType.

@@ -358,8 +354,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
     // the unnecessary copying into an array four times as large.
     const int8_t *elements = static_cast<const int8_t *>(tensorData);
     std::vector<int> tensorDataVector(elements, elements + numElements);
-    return mlirDenseElementsAttrBoolGet(
-        shapedType, numElements, tensorDataVector.data());
+    return mlirDenseElementsAttrBoolGet(shapedType, numElements,
+                                        tensorDataVector.data());
   } break;
   case ScalarType::QInt8:
     return mlirDenseElementsAttrInt8Get(

@@ -386,8 +382,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(
   return {nullptr}; // Unreachable.
 }

-MlirAttribute torch_mlir::importAttribute(
-    MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol) {
+MlirAttribute torch_mlir::importAttribute(MlirLocation loc,
+                                          torch::jit::Node *node,
+                                          c10::Symbol symbol) {
   MlirContext context = mlirLocationGetContext(loc);
   auto kind = node->kindOf(symbol);
   switch (kind) {

@@ -396,8 +393,8 @@ MlirAttribute torch_mlir::importAttribute(
     // do that.
     return mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), node->i(symbol));
   case torch::jit::AttributeKind::f:
-    return mlirFloatAttrDoubleGet(
-        context, mlirF64TypeGet(context), node->f(symbol));
+    return mlirFloatAttrDoubleGet(context, mlirF64TypeGet(context),
+                                  node->f(symbol));
   case torch::jit::AttributeKind::s:
     return mlirStringAttrGet(context, toMlirStringRef(node->s(symbol)));
   case torch::jit::AttributeKind::t:
@@ -411,8 +408,8 @@ MlirAttribute torch_mlir::importAttribute(
   }
 }

-MlirLocation torch_mlir::getMlirLocationFromNode(
-    MlirContext context, torch::jit::Node* node) {
+MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
+                                                 torch::jit::Node *node) {
   MlirLocation loc = mlirLocationUnknownGet(context);

   if (node->hasAttribute(c10::Symbol::attr("source_files"))) {

@@ -424,8 +421,8 @@ MlirLocation torch_mlir::getMlirLocationFromNode(
     for (const auto i : c10::irange(sourceFiles.size())) {
       MlirLocation newLoc = mlirLocationNameGet(
           context, toMlirStringRef(functions[i]),
-          mlirLocationFileLineColGet(
-              context, toMlirStringRef(sourceFiles[i]), lineNumbers[i],
+          mlirLocationFileLineColGet(context, toMlirStringRef(sourceFiles[i]),
+                                     lineNumbers[i],
              0 /* column is not available */
              ));
       loc = (i == 0 ? newLoc : mlirLocationCallSiteGet(newLoc, loc));

@@ -462,8 +459,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(
   return loc;
 }

-std::vector<MlirType> torch_mlir::getMlirTypesFromValues(
-    MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
+std::vector<MlirType>
+torch_mlir::getMlirTypesFromValues(MlirLocation loc,
+                                   c10::ArrayRef<torch::jit::Value *> values,
                                    const ImportOptions &importOptions) {
   std::vector<MlirType> ret;
   for (auto value : values) {

@@ -507,9 +505,10 @@ std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
   return ret;
 }

-MlirOperation torch_mlir::createOperationFromSchema(
-    MlirBlock appendToBlock, MlirLocation loc,
-    const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
+MlirOperation
+torch_mlir::createOperationFromSchema(MlirBlock appendToBlock, MlirLocation loc,
+                                      const c10::FunctionSchema &schema,
+                                      c10::ArrayRef<MlirType> resultTypes,
                                       c10::ArrayRef<MlirValue> operands) {
   MlirContext context = mlirLocationGetContext(loc);

@@ -527,8 +526,8 @@ MlirOperation torch_mlir::createOperationFromSchema(
   std::string opName = "torch." + opNameSuffix;
   // If we have a registered op, use it!
   if (mlirContextIsRegisteredOperation(context, toMlirStringRef(opName))) {
-    return createMlirOperationAtEnd(
-        appendToBlock, opName, loc, resultTypes, operands);
+    return createMlirOperationAtEnd(appendToBlock, opName, loc, resultTypes,
+                                    operands);
   }
   // Oops, no registered op -- create an opaque wrapper so that import can
   // still succeed. This helps a common use case of filling out registered ops
|
// still succeed. This helps a common use case of filling out registered ops

@@ -38,35 +38,36 @@ public:
 /// for Python code).
 ///
 /// Returns a null type on failure and emits a diagnostic.
-MlirType
-getMlirTypeForTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
+MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
+                                       c10::ScalarType scalarType);
 
 /// Maps a torch type to a corresponding MlirType. Returns a null type
 /// on failure and emits a diagnostic.
-MlirType getMlirTypeFromTorchType(
-    MlirLocation loc, const c10::TypePtr& torchType,
+MlirType getMlirTypeFromTorchType(MlirLocation loc,
+                                  const c10::TypePtr &torchType,
                                   const ImportOptions &importOptions = {});
 
 /// Creates a FunctionType suitable for expressing the signature of `schema`.
 ///
 /// This can differ from the type inferred from the block of a
 /// torch::jit::Function due to derefinement and refinement of tensor types.
-MlirType getFunctionTypeFromSchema(
-    MlirContext context, const c10::FunctionSchema& schema,
+MlirType getFunctionTypeFromSchema(MlirContext context,
+                                   const c10::FunctionSchema &schema,
                                    const ImportOptions &importOptions = {});
 
 /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
-MlirAttribute
-convertTensorToMlirElementsAttr(at::Tensor tensor, MlirLocation loc);
+MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
+                                              MlirLocation loc);
 
-MlirAttribute
-importAttribute(MlirLocation loc, torch::jit::Node* node, c10::Symbol symbol);
+MlirAttribute importAttribute(MlirLocation loc, torch::jit::Node *node,
+                              c10::Symbol symbol);
 
-MlirLocation
-getMlirLocationFromNode(MlirContext context, torch::jit::Node* node);
+MlirLocation getMlirLocationFromNode(MlirContext context,
+                                     torch::jit::Node *node);
 
-std::vector<MlirType> getMlirTypesFromValues(
-    MlirLocation loc, c10::ArrayRef<torch::jit::Value*> values,
+std::vector<MlirType>
+getMlirTypesFromValues(MlirLocation loc,
+                       c10::ArrayRef<torch::jit::Value *> values,
                        const ImportOptions &importOptions = {});
 
 std::vector<MlirValue> adjustStaticInformationForValues(
@@ -78,9 +79,10 @@ std::vector<MlirValue> adjustStaticInformationForValues(
 ///
 /// The primary difficulty here is doing the appropriate name munging and
 /// checking if the have a registered op.
-MlirOperation createOperationFromSchema(
-    MlirBlock appendToBlock, MlirLocation loc,
-    const c10::FunctionSchema& schema, c10::ArrayRef<MlirType> resultTypes,
+MlirOperation createOperationFromSchema(MlirBlock appendToBlock,
+                                        MlirLocation loc,
+                                        const c10::FunctionSchema &schema,
+                                        c10::ArrayRef<MlirType> resultTypes,
                                         c10::ArrayRef<MlirValue> operands);
 
 } // namespace torch_mlir
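The header hunk above lists the full utility surface the rest of the importer builds on. As a hedged sketch only (not code from this commit): the hypothetical helper below shows how these declarations might compose when translating a single torch::jit::Node that carries an operator schema. The include path assumes the jit_ir_importer/ prefix adopted elsewhere in this commit, and `block` and `operands` are assumed to be supplied by the caller.

// Hedged illustration only; importNodeSketch and its parameters are
// hypothetical, and only the torch_mlir:: calls come from the header above.
#include "jit_ir_importer/torch_to_mlir_utils.h"

static MlirOperation importNodeSketch(MlirContext context, MlirBlock block,
                                      torch::jit::Node *node,
                                      c10::ArrayRef<MlirValue> operands) {
  // Recover the TorchScript source location (file/line call chain), if any.
  MlirLocation loc = torch_mlir::getMlirLocationFromNode(context, node);
  // Map each output Value's torch type to an MlirType (default ImportOptions).
  std::vector<MlirType> resultTypes =
      torch_mlir::getMlirTypesFromValues(loc, node->outputs());
  // Emits the registered "torch.<op>" when available, otherwise an opaque
  // wrapper op, per the comments in the .cpp hunks above.
  return torch_mlir::createOperationFromSchema(block, loc, node->schema(),
                                               resultTypes, operands);
}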

@@ -1,30 +1,3 @@
-# Static library with core functionality.
-# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
-# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
-add_library(TorchMLIRJITIRImporter STATIC
-  class_annotator.cpp
-  function_importer.cpp
-  node_importer.cpp
-  ivalue_importer.cpp
-  torch_to_mlir_utils.cpp
-)
-target_link_libraries(TorchMLIRJITIRImporter
-  TorchMLIRAggregateCAPI
-  ${TORCH_LIBRARIES}
-)
-# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
-target_include_directories(TorchMLIRJITIRImporter PUBLIC
-  ${CMAKE_CURRENT_SOURCE_DIR}/..
-)
-set_target_properties(TorchMLIRJITIRImporter PROPERTIES
-  LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
-  OUTPUT_NAME lib_jit_ir_importer
-  PREFIX ""
-  SUFFIX ".a"
-  CXX_VISIBILITY_PRESET "default"
-  COMPILE_FLAGS "${TORCH_CXXFLAGS}"
-)
-
 # Separate Pybind MODULE due to issues with a SHARED library.
 # https://github.com/llvm/torch-mlir/issues/1154
 add_library(TorchMLIRJITIRImporterPybind MODULE
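The comment in the block removed above records the include convention that the rest of this commit enforces: headers are included relative to the csrc directory under a jit_ir_importer/ prefix rather than as siblings of the including source. A hedged before/after illustration of that convention (the concrete edits follow in the hunks below):

// Old: header resolved next to the including pybind source.
#include "class_annotator.h"
// New: header resolved relative to csrc/, under jit_ir_importer/.
#include "jit_ir_importer/class_annotator.h"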

@@ -8,7 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "class_annotator_pybind.h"
-#include "class_annotator.h"
+#include "jit_ir_importer/class_annotator.h"
 
 #include <torch/csrc/Dtype.h>
 #include <torch/csrc/utils/pybind.h>

@@ -8,7 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "import_options_pybind.h"
-#include "import_options.h"
+#include "jit_ir_importer/import_options.h"
 
 namespace py = pybind11;
 

@@ -9,9 +9,9 @@
 
 #include "module_builder.h"
 
-#include "function_importer.h"
-#include "ivalue_importer.h"
-#include "mlir_utils.h"
+#include "jit_ir_importer/function_importer.h"
+#include "jit_ir_importer/ivalue_importer.h"
+#include "jit_ir_importer/mlir_utils.h"
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"

@@ -10,7 +10,7 @@
 #ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
 #define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
 
-#include "class_annotator.h"
+#include "jit_ir_importer/class_annotator.h"
 
 #include "mlir-c/IR.h"
 