2021-01-28 08:35:44 +08:00
|
|
|
//===- ivalue_importer.cpp ------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "ivalue_importer.h"
|
2021-02-20 08:21:21 +08:00
|
|
|
#include "class_annotator.h"
|
2021-01-28 08:35:44 +08:00
|
|
|
#include "graph_importer.h"
|
|
|
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
#include "mlir_utils.h"
|
|
|
|
|
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
|
|
#include "mlir-c/Diagnostics.h"
|
|
|
|
#include "npcomp-c/Types.h"
|
|
|
|
|
|
|
|
using namespace torch_mlir;
|
|
|
|
|
2021-02-06 09:57:38 +08:00
|
|
|
// Hashing functionality for IValues.
//
// What we want here is a strict object identity hash. This is different from
// what Python usually treats as hashing, which is a deep equality hash. In
// Python terms, what we want here is a hash of `id(x)` -- unfortunately, IValue
// is not uniformly heap allocated the way a `PyObject*` is, so special handling
// is needed. At the time of this writing, there seem to be two different
// implementations, neither of which is exactly what we want.
//
// - c10::IValue::hash static method
//   - Problem: Doesn't handle certain data types, in particular objects (which
//     modules are a special case of) and lists/dicts. This makes sense when
//     reflecting the Python semantics.
// - c10::WeakIValue::hash method
//   - Problem: it literally just returns the bits of the "union" as an int.
//     This seems to read uninitialized bits for the bool variant.
//
// We use the `c10::IValue::hash` static method with special cases for data
// types that need their identity to be handled specially. `c10::IValue::hash`
// seems to be implemented in a principled way following the Python semantics,
// which is compatible with the semantics we want (for the subset it doesn't
// throw an error on).
|
|
|
|
namespace {
|
|
|
|
struct IValueHasher {
|
|
|
|
size_t operator()(const c10::IValue &ivalue) const {
|
|
|
|
if (ivalue.isObject() || ivalue.isList()) {
|
|
|
|
return std::hash<const void *>()(
|
|
|
|
static_cast<const void *>(ivalue.internalToPointer()));
|
|
|
|
}
|
|
|
|
|
|
|
|
return c10::IValue::hash(ivalue);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// TODO: The implementation of isSameIdentity looks vulnerable to malloc reusing
// the same memory block (if this hash function is used in an online setting,
// such as when tracing). Can we do better?
|
|
|
|
namespace {
|
|
|
|
struct IValueEq {
|
|
|
|
bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
|
|
|
|
return lhs.isSameIdentity(rhs);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-01-28 08:35:44 +08:00
|
|
|
namespace {
|
|
|
|
/// Helper class for holding state during recursive IValue import.
|
|
|
|
///
|
|
|
|
/// The intended usage pattern of this class is to construct it then call
|
2021-02-06 09:57:38 +08:00
|
|
|
/// `importIValue`.
|
|
|
|
///
|
|
|
|
/// The `importIValue` method can be called more than once, and values are
|
|
|
|
/// unified *by object identity*. For types isomorphic to Python builtin types
|
|
|
|
/// the behavior is what you would expect from `id(x)`.
|
2021-01-28 08:35:44 +08:00
|
|
|
///
|
2021-02-06 09:57:38 +08:00
|
|
|
/// For tensors, object identity is a little tricky. As background, at::Tensor
|
|
|
|
/// basically has 4 parts:
|
|
|
|
/// - at::Tensor which is a smart pointer to at::TensorImpl
|
|
|
|
/// - at::TensorImpl which holds sizes/strides/etc. and points to at::Storage
|
|
|
|
/// - the address of the at::TensorImpl is the identity of the tensor.
|
|
|
|
/// - at::Storage which is a smart pointer to at::StorageImpl
|
|
|
|
/// - at::StorageImpl which is a low-level buffer
|
|
|
|
/// - the address of the at::StorageImpl is the identity of the "storage".
|
|
|
|
///
|
|
|
|
/// Multiple different tensors can share the same underlying storage. We
|
|
|
|
/// currently import tensors by identity and emit errors in the case of tensors
|
|
|
|
/// with different identity but sharing the same storage. This is done because
|
|
|
|
/// correctly modeling the many ways that tensors can overlap and alias when
|
|
|
|
/// they share storage is difficult. Example hard cases are weird
|
|
|
|
/// strides/offsets that overlap, and even cases where the data types mismatch
|
|
|
|
/// (PyTorch allows this!).
|
2021-01-28 08:35:44 +08:00
|
|
|
class IValueImporter {
|
|
|
|
public:
|
2021-02-20 08:21:21 +08:00
|
|
|
IValueImporter(MlirBlock importBlock, MlirContext context,
|
|
|
|
ClassAnnotator &annotator)
|
|
|
|
: importBlock(importBlock), context(context), typeMapper(context),
|
|
|
|
annotator(annotator) {}
|
2021-01-28 08:35:44 +08:00
|
|
|
|
|
|
|
MlirValue importIValue(c10::IValue value);
|
|
|
|
|
|
|
|
private:
|
2021-02-06 09:57:38 +08:00
|
|
|
MlirValue rawImportIValue(c10::IValue value);
|
2021-01-28 08:35:44 +08:00
|
|
|
MlirValue importModule(torch::jit::Module jitModule);
|
2021-02-20 08:21:21 +08:00
|
|
|
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
|
|
|
|
const MethodAnnotation &methodAnnotation);
|
2021-02-18 03:28:51 +08:00
|
|
|
void importClassType(c10::ClassType *classType);
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
void importCompilationUnit(torch::jit::CompilationUnit *cu);
|
2021-01-28 08:35:44 +08:00
|
|
|
|
|
|
|
MlirBlock importBlock;
|
|
|
|
MlirContext context;
|
|
|
|
TypeMapper typeMapper;
|
2021-02-20 08:21:21 +08:00
|
|
|
ClassAnnotator &annotator;
|
2021-02-06 09:57:38 +08:00
|
|
|
|
|
|
|
// Map tracking already-imported values.
|
|
|
|
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
|
|
|
|
// The unique compilation unit that is shared by all modules reachable
|
|
|
|
// from the root ivalue being imported.
|
|
|
|
// It basically contains a symbol table of functions which are referenced from
|
|
|
|
// e.g. methods (the function names are meaningful and match with Python's
|
|
|
|
// module hierarchy, with the exception of `__main__` being replaced with
|
|
|
|
// `__torch__`).
|
|
|
|
torch::jit::CompilationUnit *compilationUnit = nullptr;
|
|
|
|
|
2021-02-06 09:57:38 +08:00
|
|
|
// Used to detect potentially aliasing tensors.
|
|
|
|
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
|
2021-02-18 03:28:51 +08:00
|
|
|
// The set of ClassType's that have already been imported.
|
|
|
|
//
|
|
|
|
// ClassType's are referenced via their `classType->name()->qualifiedName()`
|
|
|
|
// string (as an MLIR symbol name) so we don't need to keep a map associating
|
|
|
|
// them with the MlirOperation that they import into.
|
|
|
|
std::unordered_set<c10::ClassType *> classTypes;
|
2021-02-06 09:57:38 +08:00
|
|
|
// The stack of attribute names we have traversed to reach the current IValue.
|
|
|
|
// Used for diagnostics.
|
|
|
|
std::vector<std::string> attributeNameStack;
|
|
|
|
// The root module encountered during recursive IValue traversal.
|
|
|
|
// Used for diagnostics.
|
|
|
|
// Note that the "top-level" object being imported can in theory be a list
|
|
|
|
// of modules, so this is populated when our recursive traversal enters a
|
|
|
|
// module not enclosed in any other module, and unset after our recursive
|
|
|
|
// traversal exits the module.
|
|
|
|
c10::optional<std::string> rootModuleName;
|
2021-01-28 08:35:44 +08:00
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Imports a TorchScript module as a `torch.nn_module` op and returns that
/// op's result.
///
/// Before building the op, this makes sure the module's compilation unit and
/// class type have been imported. Each attribute slot of the module is then
/// imported recursively through `importIValue` and recorded as a `torch.slot`
/// op (carrying the attribute name) inside the `torch.nn_module` body. The
/// finished op is inserted before the terminator of `importBlock`.
///
/// Throws std::invalid_argument if the module's type has no name.
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
  if (!maybeName) {
    throw std::invalid_argument("cannot import unnamed module");
  }
  std::string moduleTypeName = maybeName->qualifiedName();

  // If this is the first time we are encountering a module, import the
  // compilation unit.
  importCompilationUnit(currentModule._ivalue()->compilation_unit().get());

  // Ensure the class type has been imported.
  importClassType(currentModule.type().get());

  MlirOperation nnModule = createMlirOperation(
      "torch.nn_module", loc,
      npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
      mlirRegionCreate());
  MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
  MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);

  // Only the outermost module on the traversal path owns `rootModuleName`.
  // Remember whether this invocation set it, so that returning from a nested
  // submodule does not wipe the root name that later diagnostics rely on.
  const bool isRootModule = !rootModuleName.has_value();
  if (isRootModule) {
    rootModuleName = moduleTypeName;
  }

  const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
  const std::vector<c10::ClassAttribute> &classAttributes =
      currentModule.type()->getAttributes();
  assert(slots.size() == classAttributes.size() &&
         "mismatch between object and type!");
  for (size_t i = 0, e = slots.size(); i != e; ++i) {
    const c10::ClassAttribute &classAttribute = classAttributes[i];
    // Track the attribute path while descending, for diagnostics.
    attributeNameStack.push_back(classAttribute.getName());
    MlirValue slotValue = importIValue(slots[i]);
    // TODO: Is it necessary to track whether an attribute is a "parameter"?
    createMlirOperationAtEnd(
        nnModuleBody, "torch.slot", loc, slotValue,
        toMlirNamedAttribute(
            "name", mlirStringAttrGet(
                        context, toMlirStringRef(classAttribute.getName()))));
    attributeNameStack.pop_back();
  }

  if (isRootModule) {
    rootModuleName = c10::nullopt;
  }

  createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
  mlirBlockInsertOwnedOperationBefore(
      importBlock, mlirBlockGetTerminator(importBlock), nnModule);
  return mlirOperationGetResult(nnModule, 0);
}
|
|
|
|
|
|
|
|
/// Imports `ivalue`, reusing the MlirValue from an earlier import when an
/// entry for it already exists in `valueMap`.
///
/// For a tensor that has not been imported yet, its StorageImpl is recorded;
/// if that StorageImpl was already recorded for another tensor, the import is
/// rejected with std::invalid_argument. The message includes the attribute
/// path when a root module is currently being traversed.
MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
  // Already imported: hand back the cached value.
  const auto cached = valueMap.find(ivalue);
  if (cached != valueMap.end()) {
    return cached->second;
  }

  // Reject potentially aliased tensors.
  if (ivalue.isTensor()) {
    c10::StorageImpl *storage =
        ivalue.toTensor().storage().unsafeGetStorageImpl();
    const bool firstSighting = seenStorageImpls.insert(storage).second;
    if (!firstSighting) {
      std::stringstream diagnostic;
      diagnostic << "Unhandled tensor that shares storage with another tensor.";
      if (rootModuleName.has_value()) {
        diagnostic << "\nFound at path '<root>."
                   << c10::QualifiedName(attributeNameStack).qualifiedName()
                   << "' from root object '" << *rootModuleName << "'";
      }
      throw std::invalid_argument(diagnostic.str());
    }
  }

  MlirValue imported = rawImportIValue(ivalue);
  valueMap[ivalue] = imported;
  return imported;
}
|
|
|
|
|
|
|
|
/// Imports `ivalue` unconditionally (no lookup in `valueMap`), emitting the
/// ops that materialize it and returning the resulting MlirValue.
///
/// Supported kinds and the ops they produce:
/// - bool   -> `basicpy.bool_constant`
/// - double -> `basicpy.numeric_constant` (f64)
/// - int    -> `basicpy.numeric_constant` (64-bit integer)
/// - list   -> `basicpy.build_list` over the recursively imported elements
/// - tuple  -> `basicpy.build_tuple` over the recursively imported elements
/// - tensor -> `std.constant` + `numpy.create_array_from_tensor`
/// - module -> delegated to `importModule`
/// - None   -> `basicpy.singleton`
///
/// Throws std::invalid_argument for any other kind of IValue.
MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  if (ivalue.isBool()) {
    MlirType type = npcompBoolTypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.bool_constant", loc, type,
        toMlirNamedAttribute("value",
                             mlirBoolAttrGet(context, ivalue.toBool())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isDouble()) {
    MlirType type = mlirF64TypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.numeric_constant", loc, type,
        toMlirNamedAttribute(
            "value", mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isInt()) {
    MlirType type = mlirIntegerTypeGet(context, 64);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.numeric_constant", loc, type,
        toMlirNamedAttribute("value",
                             mlirIntegerAttrGet(type, ivalue.toInt())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isList()) {
    c10::List<c10::IValue> list = ivalue.toList();
    std::vector<MlirValue> elems;
    elems.reserve(list.size());
    // Elements are imported first so their defining ops precede the list op.
    for (const c10::IValue &elem : list) {
      elems.push_back(importIValue(elem));
    }
    MlirOperation operation =
        createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
                                 npcompListTypeGet(context), elems);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTuple()) {
    // Keep the tuple alive through a named pointer and bind its elements by
    // reference, instead of copying the whole element container.
    auto tuple = ivalue.toTuple();
    const auto &tupleElements = tuple->elements();
    std::vector<MlirValue> elems;
    elems.reserve(tupleElements.size());
    // Elements are imported first so their defining ops precede the tuple op.
    for (const c10::IValue &elem : tupleElements) {
      elems.push_back(importIValue(elem));
    }
    MlirOperation operation =
        createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
                                 npcompTupleTypeGet(context), elems);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTensor()) {
    at::Tensor tensor = ivalue.toTensor().contiguous();
    MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
    MlirOperation constant = createMlirOperationAtEnd(
        importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
        toMlirNamedAttribute("value", denseElements));
    MlirOperation ndarray = createMlirOperationAtEnd(
        importBlock, "numpy.create_array_from_tensor", loc,
        npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
        mlirOperationGetResult(constant, 0));
    return mlirOperationGetResult(ndarray, 0);
  }
  if (ivalue.isModule()) {
    return importModule(ivalue.toModule());
  }
  if (ivalue.isNone()) {
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
    return mlirOperationGetResult(operation, 0);
  }

  std::stringstream msg;
  msg << "Unsupported ivalue: " << ivalue;
  throw std::invalid_argument(msg.str());
}
|
|
|
|
|
|
|
|
void IValueImporter::importMethod(torch::jit::Function *function,
|
2021-02-20 08:21:21 +08:00
|
|
|
MlirBlock classTypeBody,
|
|
|
|
const MethodAnnotation &methodAnnotation) {
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
// The function's name becomes the MLIR symbol table name of the imported func
|
|
|
|
// when we import the compilation unit.
|
|
|
|
const std::string &symName = function->qualname().qualifiedName();
|
|
|
|
MlirAttribute functionSymbolRef =
|
|
|
|
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
|
|
|
|
|
2021-02-20 08:21:21 +08:00
|
|
|
c10::optional<MlirNamedAttribute> isPrivate;
|
|
|
|
if (!methodAnnotation.isExported) {
|
|
|
|
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
|
|
|
|
}
|
2021-01-28 08:35:44 +08:00
|
|
|
createMlirOperationAtEnd(
|
2021-02-18 03:28:51 +08:00
|
|
|
classTypeBody, "torch.method", mlirLocationUnknownGet(context),
|
2021-01-28 08:35:44 +08:00
|
|
|
toMlirNamedAttribute(
|
|
|
|
"name",
|
|
|
|
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
|
2021-01-28 08:35:44 +08:00
|
|
|
}
|
2021-02-18 03:28:51 +08:00
|
|
|
|
|
|
|
/// Imports `classType` as a `torch.class_type` op, at most once per
/// ClassType pointer.
///
/// The op's body receives one `torch.attr` op per class attribute (name and
/// mapped type), one `torch.method` op per method, and a
/// `torch.class_type_terminator`. Attributes and methods whose annotation is
/// not exported are marked with an `isPrivate` unit attribute.
void IValueImporter::importClassType(c10::ClassType *classType) {
  // Nothing to do if this class type was imported before.
  const bool newlyRecorded = classTypes.insert(classType).second;
  if (!newlyRecorded) {
    return;
  }

  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  MlirOperation classTypeOp = createMlirOperationAtEnd(
      importBlock, "torch.class_type", loc, mlirRegionCreate(),
      toMlirNamedAttribute(
          "sym_name",
          mlirStringAttrGet(
              context, toMlirStringRef(classType->name()->qualifiedName()))));
  MlirRegion bodyRegion = mlirOperationGetRegion(classTypeOp, 0);
  mlirRegionAppendOwnedBlock(bodyRegion, mlirBlockCreate(0, nullptr));
  MlirBlock classTypeBody = mlirRegionGetFirstBlock(bodyRegion);

  ClassAnnotation &classAnnotation =
      annotator.getOrCreateClassAnnotation(classType);

  // Attributes: annotations are indexed in parallel with the class attributes.
  const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
  const auto &classAttributes = classType->getAttributes();
  for (int attrIndex = 0, attrCount = classAttributes.size();
       attrIndex != attrCount; attrIndex++) {
    const c10::ClassAttribute &classAttribute = classAttributes[attrIndex];

    c10::optional<MlirNamedAttribute> isPrivate;
    if (!attributeAnnotations[attrIndex].isExported) {
      isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
    }

    MlirNamedAttribute nameAttr = toMlirNamedAttribute(
        "name",
        mlirStringAttrGet(context, toMlirStringRef(classAttribute.getName())));
    MlirNamedAttribute typeAttr = toMlirNamedAttribute(
        "type", mlirTypeAttrGet(typeMapper.mapFromTorchType(
                    loc, classAttribute.getType())));
    createMlirOperationAtEnd(classTypeBody, "torch.attr", loc, nameAttr,
                             typeAttr, isPrivate);
  }

  // Methods: annotations are indexed in parallel with the class methods.
  const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
  const auto &methods = classType->methods();
  for (int methodIndex = 0, methodCount = methods.size();
       methodIndex != methodCount; methodIndex++) {
    importMethod(methods[methodIndex], classTypeBody,
                 methodAnnotations[methodIndex]);
  }

  createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
}
|
|
|
|
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
/// Imports every function of `cu` as a private func op, inserted before the
/// terminator of `importBlock`.
///
/// Only the first call does any importing. Later calls with the same
/// compilation unit are no-ops, and a call with a different compilation unit
/// throws std::invalid_argument.
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
  if (compilationUnit != nullptr) {
    // All sorts of stuff is connected to the compilation unit, such as
    // c10::ClassType's (owned by the compilation unit), c10::FunctionType
    // (which holds a pointer to a torch::jit::Function in the compilation
    // unit), load-bearing symbol table names of functions, etc.
    //
    // It doesn't seem to be defined how multiple compilation units
    // semantically connect with each other, and it doesn't seem to happen
    // either (though structurally at the C++ level nothing prevents it), so
    // make it an error.
    if (compilationUnit != cu) {
      throw std::invalid_argument(
          "found two compilation units while importing");
    }
    return;
  }

  // First compilation unit seen: remember it and import its functions.
  compilationUnit = cu;

  for (torch::jit::Function *function : cu->get_functions()) {
    MlirOperation funcOp = importGraphAsFuncOp(
        context, function->graph().get(), function->qualname().qualifiedName());
    // For IValue importing, the logical linkage structure of the module is
    // determined by the object graph.
    //
    // The functions' symbol names are thus irrelevant to the module's
    // externally visible characteristics, so mark them all as private.
    //
    // These functions may be referenced by the object graph, which can make
    // them reachable from the externally visible characteristics of the
    // module, but they cannot be intrinsically externally visible.
    MlirAttribute privateVisibility =
        mlirStringAttrGet(context, toMlirStringRef("private"));
    mlirOperationSetAttributeByName(funcOp, toMlirStringRef("sym_visibility"),
                                    privateVisibility);
    mlirBlockInsertOwnedOperationBefore(
        importBlock, mlirBlockGetTerminator(importBlock), funcOp);
  }
}
|
|
|
|
|
2021-01-28 08:35:44 +08:00
|
|
|
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
2021-02-20 08:21:21 +08:00
|
|
|
MlirContext context, ClassAnnotator &annotator) {
|
2021-01-28 08:35:44 +08:00
|
|
|
// When debugging module importing, it can be useful to dump as so:
|
|
|
|
// if (ivalue.isModule())
|
2021-02-19 09:10:17 +08:00
|
|
|
// ivalue.toModule().dump(true, false, false);
|
2021-02-20 08:21:21 +08:00
|
|
|
IValueImporter importer(block, context, annotator);
|
2021-01-28 08:35:44 +08:00
|
|
|
importer.importIValue(ivalue);
|
|
|
|
}
|