//===- ivalue_importer.cpp ------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "ivalue_importer.h"
#include "class_annotator.h"
#include "function_importer.h"

#include <unordered_map>
#include <unordered_set>

#include "mlir_utils.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"

#include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h"

using namespace torch_mlir;

// Hashing functionality for IValue's.
//
// What we want here is a strict object identity hash. This is different from
// what Python usually treats as hashing, which is a deep equality hash. In
// Python terms, what we want here is a hash of `id(x)` -- unfortunately, IValue
// is not uniformly heap allocated the way a `PyObject*` is, so special handling
// is needed. At the time of this writing, there seem to be two different
// implementations, neither of which is exactly what we want.
//
// - c10::IValue::hash static method
//   - Problem: Doesn't handle certain data types, in particular objects (which
//     modules are a special case of) and lists/dicts. This makes sense when
//     reflecting the Python semantics.
// - c10::WeakIValue::hash method
//   - Problem: it literally just returns the bits of the "union" as an int.
//     This seems to read uninitialized bits for the bool variant.
//
// We use the `c10::IValue::hash` static method with special cases for data
// types that need their identity to be handled specially. `c10::IValue::hash`
// seems to be implemented in a principled way following the Python semantics,
// which is compatible with the semantics we want (for the subset it doesn't
// throw an error on).
namespace {
struct IValueHasher {
  size_t operator()(const c10::IValue &ivalue) const {
    if (ivalue.isObject() || ivalue.isList()) {
      return std::hash<const void *>()(
          static_cast<const void *>(ivalue.internalToPointer()));
    }

    return c10::IValue::hash(ivalue);
  }
};
} // namespace

// TODO: The implementation of isSameIdentity looks vulnerable to malloc reusing
// the same memory block (if this hash function is used in an online setting,
// such as when tracing). Can we do better?
namespace {
struct IValueEq {
  bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
    return lhs.isSameIdentity(rhs);
  }
};
} // namespace
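
// Illustration only (not used by the importer): with this hasher/equality
// pair as the map's key functions, keys compare by object identity rather
// than by contents. Assuming the semantics described above:
//
//   c10::List<int64_t> a{1, 2, 3};
//   c10::List<int64_t> b{1, 2, 3};
//   c10::IValue ia(a), ia2(a), ib(b);
//   // IValueEq()(ia, ia2) == true  -- same underlying list object.
//   // IValueEq()(ia, ib)  == false -- equal contents, distinct objects.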

namespace {
/// Helper class for holding state during recursive IValue import.
///
/// The intended usage pattern of this class is to construct it then call
/// `importIValue`.
///
/// The `importIValue` method can be called more than once, and values are
/// unified *by object identity*. For types isomorphic to Python builtin types
/// the behavior is what you would expect from `id(x)`.
///
/// For tensors, object identity is a little tricky. As background, at::Tensor
/// basically has 4 parts:
/// - at::Tensor which is a smart pointer to at::TensorImpl
/// - at::TensorImpl which holds sizes/strides/etc. and points to at::Storage
///   - the address of the at::TensorImpl is the identity of the tensor.
/// - at::Storage which is a smart pointer to at::StorageImpl
/// - at::StorageImpl which is a low-level buffer
///   - the address of the at::StorageImpl is the identity of the "storage".
///
/// Multiple different tensors can share the same underlying storage. We
/// currently import tensors by identity and emit errors in the case of tensors
/// with different identity but sharing the same storage. This is done because
/// correctly modeling the many ways that tensors can overlap and alias when
/// they share storage is difficult. Example hard cases are weird
/// strides/offsets that overlap, and even cases where the data types mismatch
/// (PyTorch allows this!).
class IValueImporter {
public:
  IValueImporter(MlirBlock importBlock, MlirContext context,
                 ClassAnnotator &annotator)
      : importBlock(importBlock), context(context), typeMapper(context),
        annotator(annotator) {}

  MlirValue importIValue(c10::IValue ivalue);

private:
  MlirValue rawImportIValue(c10::IValue ivalue);
  MlirValue importTensor(c10::IValue ivalue);
  MlirValue importModule(torch::jit::Module jitModule);
  void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
                    const MethodAnnotation &methodAnnotation);
  void importClassType(c10::ClassType *classType);
  void importCompilationUnit(torch::jit::CompilationUnit *cu);

  MlirBlock importBlock;
  MlirContext context;
  TypeMapper typeMapper;
  ClassAnnotator &annotator;

  // Map tracking already-imported values.
  std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;

  // The unique compilation unit that is shared by all modules reachable
  // from the root ivalue being imported.
  // It basically contains a symbol table of functions which are referenced from
  // e.g. methods (the function names are meaningful and match with Python's
  // module hierarchy, with the exception of `__main__` being replaced with
  // `__torch__`).
  torch::jit::CompilationUnit *compilationUnit = nullptr;

  // Used to detect potentially aliasing tensors.
  std::unordered_set<c10::StorageImpl *> seenStorageImpls;

  // The set of ClassType's that have already been imported.
  //
  // ClassType's are referenced via their `classType->name()->qualifiedName()`
  // string (as an MLIR symbol name) so we don't need to keep a map associating
  // them with the MlirOperation that they import into.
  std::unordered_set<c10::ClassType *> classTypes;

  // The stack of attribute names we have traversed to reach the current IValue.
  // Used for diagnostics.
  std::vector<std::string> attributeNameStack;

  // The root module encountered during recursive IValue traversal.
  // Used for diagnostics.
  // Note that the "top-level" object being imported can in theory be a list
  // of modules, so this is populated when our recursive traversal enters a
  // module not enclosed in any other module, and unset after our recursive
  // traversal exits the module.
  c10::optional<std::string> rootModuleName;
};
} // namespace
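
// A minimal usage sketch (assuming an import block that already has a
// terminator, since importModule inserts the module op before it):
//
//   ClassAnnotator annotator;
//   IValueImporter importer(importBlock, context, annotator);
//   MlirValue root = importer.importIValue(ivalue);
//
// Repeated calls unify already-seen objects through `valueMap`, so importing
// the same IValue twice yields the same MlirValue.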

MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
  if (!maybeName) {
    throw std::invalid_argument("cannot import unnamed module");
  }
  std::string moduleTypeName = maybeName->qualifiedName();

  // If this is the first time we are encountering a module, import the
  // compilation unit.
  importCompilationUnit(currentModule._ivalue()->compilation_unit().get());

  // Ensure the class type has been imported.
  importClassType(currentModule.type().get());

  MlirOperation nnModule = createMlirOperation(
      "torch.nn_module", loc,
      npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
      mlirRegionCreate());
  MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
  MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
  auto inserter = caffe2::MakeGuard([&]() {
    mlirBlockInsertOwnedOperationBefore(
        importBlock, mlirBlockGetTerminator(importBlock), nnModule);
  });

  if (!rootModuleName.has_value()) {
    rootModuleName = moduleTypeName;
  }

  const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
  const std::vector<c10::ClassAttribute> &classAttributes =
      currentModule.type()->getAttributes();
  assert(slots.size() == classAttributes.size() &&
         "mismatch between object and type!");
  for (int i = 0, e = slots.size(); i < e; i++) {
    const c10::ClassAttribute &classAttribute = classAttributes[i];
    attributeNameStack.push_back(classAttribute.getName());
    MlirValue slotValue = importIValue(slots[i]);
    // TODO: Is it necessary to track whether an attribute is a "parameter"?
    createMlirOperationAtEnd(
        nnModuleBody, "torch.slot", loc, slotValue,
        toMlirNamedAttribute(
            "name", mlirStringAttrGet(
                        context, toMlirStringRef(classAttribute.getName()))));
    attributeNameStack.pop_back();
  }

  if (rootModuleName.has_value()) {
    rootModuleName = c10::nullopt;
  }

  createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
  return mlirOperationGetResult(nnModule, 0);
}
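
// For orientation, the IR produced here for a module of class
// `__torch__.MyModule` with one attribute `b` looks roughly like the
// following (a sketch; the torch dialect defines the authoritative syntax):
//
//   %root = torch.nn_module {
//     torch.slot "b", %true : !torch.bool
//   } : !torch.nn.Module<"__torch__.MyModule">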

MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
  auto it = valueMap.find(ivalue);
  if (it != valueMap.end()) {
    return it->second;
  }
  // Reject potentially aliased tensors.
  if (ivalue.isTensor()) {
    c10::StorageImpl *storageImpl =
        ivalue.toTensor().storage().unsafeGetStorageImpl();
    if (!seenStorageImpls.insert(storageImpl).second) {
      std::stringstream msg;
      msg << "Unhandled tensor that shares storage with another tensor.";
      if (rootModuleName) {
        msg << "\nFound at path '<root>."
            << c10::QualifiedName(attributeNameStack).qualifiedName()
            << "' from root object '" << *rootModuleName << "'";
      }
      throw std::invalid_argument(msg.str());
    }
  }
  MlirValue value = rawImportIValue(ivalue);
  valueMap[ivalue] = value;
  return value;
}

MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  if (ivalue.isBool()) {
    MlirType type = npcompTorchBoolTypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.constant.bool", loc, type,
        toMlirNamedAttribute("value",
                             mlirBoolAttrGet(context, ivalue.toBool())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isDouble()) {
    MlirType type = mlirF64TypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.constant.float", loc, type,
        toMlirNamedAttribute(
            "value", mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isInt()) {
    MlirType type = npcompTorchIntTypeGet(context);
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.constant.int", loc, type,
        toMlirNamedAttribute("value",
                             mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64),
                                                ivalue.toInt())));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isList()) {
    c10::List<c10::IValue> list = ivalue.toList();
    std::vector<MlirValue> elems;
    for (const c10::IValue &elem : list) {
      elems.push_back(importIValue(elem));
    }
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.prim.ListConstruct", loc,
        npcompTorchListTypeGet(
            typeMapper.mapFromTorchType(loc, list.elementType())),
        elems);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTuple()) {
    auto list = ivalue.toTuple()->elements();
    std::vector<MlirValue> operands;
    std::vector<MlirType> types;
    for (const c10::IValue &elem : list) {
      MlirValue operand = importIValue(elem);
      operands.push_back(operand);
      types.push_back(mlirValueGetType(operand));
    }
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.prim.TupleConstruct", loc,
        npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isTensor()) {
    return importTensor(ivalue);
  }
  if (ivalue.isModule()) {
    return importModule(ivalue.toModule());
  }
  if (ivalue.isString()) {
    MlirOperation operation = createMlirOperationAtEnd(
        importBlock, "torch.constant.str", loc,
        npcompTorchStringTypeGet(context),
        toMlirNamedAttribute(
            "value",
            mlirStringAttrGet(context,
                              toMlirStringRef(ivalue.toString()->string()))));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isNone()) {
    MlirOperation operation =
        createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
                                 npcompTorchNoneTypeGet(context));
    return mlirOperationGetResult(operation, 0);
  }
  if (ivalue.isCustomClass()) {
    if (ivalue.type().get() ==
        c10::getCustomClassType<c10::intrusive_ptr<LinearPackedParamsBase>>()
            .get()) {
      c10::intrusive_ptr<LinearPackedParamsBase> linearParams =
          ivalue.toCustomClass<LinearPackedParamsBase>();
      at::Tensor weight;
      c10::optional<at::Tensor> bias;
      std::tie(weight, bias) = linearParams->unpack();
      MlirValue weightValue = importIValue(c10::IValue(weight));
      c10::optional<MlirValue> biasValue = c10::nullopt;
      if (bias.has_value()) {
        biasValue = importIValue(c10::IValue(*bias));
      }
      MlirOperation operation = createMlirOperationAtEnd(
          importBlock, "torch.linear_params.create", loc,
          npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
      return mlirOperationGetResult(operation, 0);
    }
  }
  std::stringstream msg;
  msg << "Unsupported ivalue: " << ivalue;
  throw std::invalid_argument(msg.str());
}
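
// Rough sketch of the constants this emits at the torch dialect level (shown
// for orientation only; the dialect defines the printed forms):
//
//   %true = torch.constant.bool true
//   %int3 = torch.constant.int 3
//   %none = torch.constant.none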

MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
  assert(ivalue.isTensor() && "expected a tensor!");

  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  // Import the bulk tensor representation.
  at::Tensor tensor = ivalue.toTensor().contiguous();
  MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
  MlirOperation tensorOp =
      createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
                               npcompTorchNonValueTensorTypeGetFromShaped(
                                   mlirAttributeGetType(denseElements)),
                               toMlirNamedAttribute("value", denseElements));
  MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);

  // Construct the complete tensor value. This is trivial for most tensors, but
  // for quantized tensors (and probably sparse too, TBD) there is more for us
  // to do.
  MlirValue tensorValue;
  if (tensor.is_quantized()) {
    // Note that Torch models quantization in a type-erased way. So we don't
    // make an effort here to do any special static modeling. If desired, later
    // compiler stages that are building a statically modeled quantization
    // representation will need to convert this to their representation.
    std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
    MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
        context, shape.size(), shape.data(),
        typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
    if (tensor.qscheme() == c10::kPerTensorAffine) {
      MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
      MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
      MlirOperation quantizedTensor = createMlirOperationAtEnd(
          importBlock, "torch.per_tensor_affine.create", loc,
          quantizedTensorType, tensorReprValue, qScale, zeroPoint);
      tensorValue = mlirOperationGetResult(quantizedTensor, 0);
    } else {
      std::stringstream msg;
      msg << "Unsupported quantization scheme '"
          << c10::toString(tensor.qscheme()) << "' for tensor: " << ivalue;
      throw std::invalid_argument(msg.str());
    }
  } else {
    tensorValue = tensorReprValue;
  }

  return tensorValue;
}
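
// The bulk data becomes a `torch.tensor` literal, e.g.
// `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor`. For a
// per-tensor-affine quantized tensor, that literal is additionally wrapped in
// a `torch.per_tensor_affine.create` op carrying the scale and zero point
// imported above (sketch only; see the op definitions for exact syntax).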

void IValueImporter::importMethod(torch::jit::Function *function,
                                  MlirBlock classTypeBody,
                                  const MethodAnnotation &methodAnnotation) {
  // The function's name becomes the MLIR symbol table name of the imported func
  // when we import the compilation unit.
  const std::string &symName = function->qualname().qualifiedName();
  MlirAttribute functionSymbolRef =
      mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));

  c10::optional<MlirNamedAttribute> isPrivate;
  if (!methodAnnotation.isExported) {
    isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
  }
  createMlirOperationAtEnd(
      classTypeBody, "torch.method", mlirLocationUnknownGet(context),
      toMlirNamedAttribute(
          "name",
          mlirStringAttrGet(context, toMlirStringRef(function->name()))),
      toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
}

void IValueImporter::importClassType(c10::ClassType *classType) {
  if (!classTypes.insert(classType).second) {
    return;
  }

  // TODO: Can we do better?
  MlirLocation loc = mlirLocationUnknownGet(context);

  MlirOperation op = createMlirOperationAtEnd(
      importBlock, "torch.class_type", loc, mlirRegionCreate(),
      toMlirNamedAttribute(
          "sym_name",
          mlirStringAttrGet(
              context, toMlirStringRef(classType->name()->qualifiedName()))));
  MlirRegion region = mlirOperationGetRegion(op, 0);
  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
  MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);

  ClassAnnotation &classAnnotation =
      annotator.getOrCreateClassAnnotation(classType);

  const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
  const auto &classAttributes = classType->getAttributes();
  for (int i = 0, e = classAttributes.size(); i != e; i++) {
    const c10::ClassAttribute &classAttribute = classAttributes[i];
    c10::optional<MlirNamedAttribute> isPrivate;
    if (!attributeAnnotations[i].isExported) {
      isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
    }
    createMlirOperationAtEnd(
        classTypeBody, "torch.attr", loc,
        toMlirNamedAttribute(
            "name", mlirStringAttrGet(
                        context, toMlirStringRef(classAttribute.getName()))),
        toMlirNamedAttribute("type",
                             mlirTypeAttrGet(typeMapper.mapFromTorchType(
                                 loc, classAttribute.getType()))),
        isPrivate);
  }

  const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
  const auto &methods = classType->methods();
  for (int i = 0, e = methods.size(); i != e; i++) {
    importMethod(methods[i], classTypeBody, methodAnnotations[i]);
  }

  createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
}
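
// Taken together with importMethod above, the per-class result is roughly
// (a sketch; names are illustrative):
//
//   torch.class_type @__torch__.MyModule {
//     torch.attr "b" : !torch.bool
//     torch.method "forward", @__torch__.MyModule.forward
//   }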

void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
  if (compilationUnit == nullptr) {
    compilationUnit = cu;
  } else {
    // All sorts of stuff is connected to the compilation unit, such as
    // c10::ClassType's (owned by the compilation unit), c10::FunctionType
    // (which holds a pointer to a torch::jit::Function in the compilation
    // unit), load-bearing symbol table names of functions, etc.
    //
    // It doesn't seem to be defined how multiple compilation units semantically
    // connect with each other, and it doesn't seem to happen either (though
    // structurally at the C++ level nothing prevents it), so make it an error.
    if (compilationUnit != cu) {
      throw std::invalid_argument(
          "found two compilation units while importing");
    }
    return;
  }

  for (torch::jit::Function *function : cu->get_functions()) {
    // Useful for debugging errors in free functions that end up being
    // unused. These can be missing when round-tripping through the on-disk
    // format, even though they still cause import issues when importing
    // through the larger Python session where they originate.
    // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
    // std::cerr << *function->graph();
    MethodAnnotation *annotation =
        annotator.getMethodAnnotationForFunction(function);
    MlirOperation func = importJitFunctionAsFuncOp(
        context, function, [&](int argIndex) -> MlirAttribute {
          if (!annotation || !annotation->argAnnotations.has_value()) {
            return {nullptr};
          }
          c10::optional<std::vector<int64_t>> &maybeShape =
              annotation->argAnnotations.value()[argIndex].shape;
          c10::optional<c10::ScalarType> &maybeDtype =
              annotation->argAnnotations.value()[argIndex].dtype;
          bool hasValueSemantics =
              annotation->argAnnotations.value()[argIndex].hasValueSemantics;

          // TODO: Handle unranked tensors and tensors with unknown dtype (but
          // possibly known ranks/sizes).
          if (!maybeShape || !maybeDtype) {
            return {nullptr};
          }

          std::vector<int64_t> shape = *maybeShape;
          MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
              mlirLocationUnknownGet(context), *maybeDtype);
          MlirType typeBound;
          if (hasValueSemantics) {
            typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
                                                      shape.data(), dtype);
|
2021-05-21 08:07:18 +08:00
|
|
|
} else {
|
2021-06-15 05:13:59 +08:00
|
|
|
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
|
|
|
|
shape.data(), dtype);
|
2021-05-21 08:07:18 +08:00
|
|
|
}
|
|
|
|
|
2021-03-31 07:11:41 +08:00
|
|
|
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
|
|
|
"torch.type_bound", mlirTypeAttrGet(typeBound));
|
|
|
|
return mlirDictionaryAttrGet(context, 1, &typeBoundAttr);
|
|
|
|
});
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`s that it saw attached to
methods.
Subtly, any time a TorchScript module calls into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallFunction,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
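To make the "no indirection" point concrete, here is a hedged sketch of how the callee could be recovered straight from the `c10::FunctionType` of the `prim::Constant`'s result; the helper name `resolveCallee` is hypothetical and assumes the standard `torch/csrc/jit/ir/ir.h` header is available.
```
// Hypothetical helper: the callee is stored directly on the
// c10::FunctionType, so the "name" attribute on the prim::Constant is not
// what actually identifies the function being called.
#include <torch/csrc/jit/ir/ir.h>

static torch::jit::Function *resolveCallee(torch::jit::Node *functionConstant) {
  auto functionType =
      functionConstant->output()->type()->cast<c10::FunctionType>();
  // Any two values of the same FunctionType refer to the same function.
  return functionType ? functionType->function() : nullptr;
}
```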
2021-02-27 08:20:35 +08:00
|
|
|
// For IValue importing, the logical linkage structure of the module
|
|
|
|
// is determined by the object graph.
|
|
|
|
//
|
|
|
|
// The functions' symbol names are thus irrelevant to the module's
|
|
|
|
// externally visible characteristics, so mark them all as private.
|
|
|
|
//
|
|
|
|
// These functions may be referenced by the object graph, which can make
|
|
|
|
// them reachable from the externally visible characteristics of the module,
|
|
|
|
// but they cannot be intrinsically externally visible.
|
|
|
|
mlirOperationSetAttributeByName(
|
|
|
|
func, toMlirStringRef("sym_visibility"),
|
|
|
|
mlirStringAttrGet(context, toMlirStringRef("private")));
|
|
|
|
mlirBlockInsertOwnedOperationBefore(
|
|
|
|
importBlock, mlirBlockGetTerminator(importBlock), func);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-01-28 08:35:44 +08:00
|
|
|
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
2021-02-20 08:21:21 +08:00
|
|
|
MlirContext context, ClassAnnotator &annotator) {
|
2021-01-28 08:35:44 +08:00
|
|
|
// When debugging module importing, it can be useful to dump like so:
|
|
|
|
// if (ivalue.isModule())
|
2021-02-19 09:10:17 +08:00
|
|
|
// ivalue.toModule().dump(true, false, false);
|
2021-02-20 08:21:21 +08:00
|
|
|
IValueImporter importer(block, context, annotator);
|
2021-01-28 08:35:44 +08:00
|
|
|
importer.importIValue(ivalue);
|
|
|
|
}
|