torch-mlir/frontends/pytorch/csrc/builder/ivalue_importer.cpp

//===- ivalue_importer.cpp ------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "ivalue_importer.h"
#include "class_annotator.h"
#include "function_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h"
using namespace torch_mlir;
// Hashing functionality for IValue's.
//
// What we want here is a strict object identity hash. This is different from
// what Python usually treats as hashing, which is a deep equality hash. In
// Python terms, what we want here is a hash of `id(x)` -- unfortunately, IValue
// is not uniformly heap allocated the way a `PyObject*` is, so special handling
// is needed. At the time of this writing, there seem to be two different
// implementations, neither of which is exactly what we want.
//
// - c10::IValue::hash static method
// - Problem: Doesn't handle certain data types, in particular objects (which
// modules are a special case of) and lists/dicts. This makes sense when
// reflecting the Python semantics.
// - c10::WeakIValue::hash method
// - Problem: it literally just returns the bits of the "union" as an int.
// This seems to read uninitialized bits for the bool variant.
//
// We use the `c10::IValue::hash` static method with special cases for data
// types that need their identity to be handled specially. `c10::IValue::hash`
// seems to be implemented in a principled way following the Python semantics,
// which is compatible with the semantics we want (for the subset it doesn't
// throw an error on).
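//
// Illustrative sketch (not code from this file): two distinct lists with equal
// contents must land in different map slots here, even though a Python-style
// deep-equality hash would conflate them:
//
//   c10::List<int64_t> a({1, 2, 3});
//   c10::List<int64_t> b({1, 2, 3});
//   // IValueEq treats `a` and `b` as different keys (distinct objects), and
//   // IValueHasher hashes them by heap address rather than by contents.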
namespace {
struct IValueHasher {
size_t operator()(const c10::IValue &ivalue) const {
if (ivalue.isObject() || ivalue.isList()) {
return std::hash<const void *>()(
static_cast<const void *>(ivalue.internalToPointer()));
}
return c10::IValue::hash(ivalue);
}
};
} // namespace
// TODO: The implementation of isSameIdentity looks vulnerable to malloc reusing
// the same memory block (if this hash function is used in an online setting,
// such as when tracing). Can we do better?
namespace {
struct IValueEq {
bool operator()(const c10::IValue &lhs, const c10::IValue &rhs) const {
return lhs.isSameIdentity(rhs);
}
};
} // namespace
namespace {
/// Helper class for holding state during recursive IValue import.
///
/// The intended usage pattern of this class is to construct it then call
/// `importIValue`.
///
/// The `importIValue` method can be called more than once, and values are
/// unified *by object identity*. For types isomorphic to Python builtin types
/// the behavior is what you would expect from `id(x)`.
///
/// For tensors, object identity is a little tricky. As background, at::Tensor
/// basically has 4 parts:
/// - at::Tensor which is a smart pointer to at::TensorImpl
/// - at::TensorImpl which holds sizes/strides/etc. and points to at::Storage
/// - the address of the at::TensorImpl is the identity of the tensor.
/// - at::Storage which is a smart pointer to at::StorageImpl
/// - at::StorageImpl which is a low-level buffer
/// - the address of the at::StorageImpl is the identity of the "storage".
///
/// Multiple different tensors can share the same underlying storage. We
/// currently import tensors by identity and emit errors in the case of tensors
/// with different identity but sharing the same storage. This is done because
/// correctly modeling the many ways that tensors can overlap and alias when
/// they share storage is difficult. Example hard cases are weird
/// strides/offsets that overlap, and even cases where the data types mismatch
/// (PyTorch allows this!).
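///
/// Illustrative sketch (hypothetical, not code from this file): given
///
///   at::Tensor t = torch::ones({4});
///   at::Tensor u = t.view({2, 2}); // distinct TensorImpl, same StorageImpl
///
/// `t` and `u` have different identities but share storage, so if both are
/// reachable from the imported IValue, the second one encountered is rejected
/// with a diagnostic rather than imported incorrectly.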
class IValueImporter {
public:
IValueImporter(MlirBlock importBlock, MlirContext context,
ClassAnnotator &annotator)
: importBlock(importBlock), context(context), typeMapper(context),
annotator(annotator) {}
MlirValue importIValue(c10::IValue ivalue);
private:
MlirValue rawImportIValue(c10::IValue ivalue);
MlirValue importTensor(c10::IValue ivalue);
MlirValue importModule(torch::jit::Module jitModule);
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation);
void importClassType(c10::ClassType *classType);
void importCompilationUnit(torch::jit::CompilationUnit *cu);
MlirBlock importBlock;
MlirContext context;
TypeMapper typeMapper;
ClassAnnotator &annotator;
// Map tracking already-imported values.
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
// The unique compilation unit that is shared by all modules reachable
// from the root ivalue being imported.
// It essentially holds a symbol table of functions that are referenced from
// e.g. methods. The function names are meaningful and mirror Python's module
// hierarchy, with the exception of `__main__` being replaced with
// `__torch__`.
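// For example (illustrative names), a free function `helper` and a method
// `forward` on class `MyModule`, both scripted at top level, are registered
// as `__torch__.helper` and `__torch__.MyModule.forward` respectively.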
torch::jit::CompilationUnit *compilationUnit = nullptr;
// Used to detect potentially aliasing tensors.
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
// The set of ClassType's that have already been imported.
//
// ClassType's are referenced via their `classType->name()->qualifiedName()`
// string (as an MLIR symbol name) so we don't need to keep a map associating
// them with the MlirOperation that they import into.
std::unordered_set<c10::ClassType *> classTypes;
// The stack of attribute names we have traversed to reach the current IValue.
// Used for diagnostics.
std::vector<std::string> attributeNameStack;
// The root module encountered during recursive IValue traversal.
// Used for diagnostics.
// Note that the "top-level" object being imported can in theory be a list
// of modules, so this is populated when our recursive traversal enters a
// module not enclosed in any other module, and unset after our recursive
// traversal exits the module.
c10::optional<std::string> rootModuleName;
};
} // namespace
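// Imports `currentModule` as a `torch.nn_module` op whose body holds one
// `torch.slot` op per attribute of the module instance. Roughly (illustrative
// names, exact printed syntax may differ):
//
//   %0 = torch.nn_module {
//     torch.slot "training", %true : !basicpy.BoolType
//     torch.slot "weight", %arr : !numpy.ndarray<*:!numpy.any_dtype>
//   } : !torch.nn.Module<"__torch__.MyModule">
//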
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
if (!maybeName) {
throw std::invalid_argument("cannot import unnamed module");
}
std::string moduleTypeName = maybeName->qualifiedName();
// If this is the first time we are encountering a module, import the
// compilation unit.
importCompilationUnit(currentModule._ivalue()->compilation_unit().get());
// Ensure the class type has been imported.
importClassType(currentModule.type().get());
MlirOperation nnModule = createMlirOperation(
"torch.nn_module", loc,
npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
auto inserter = caffe2::MakeGuard([&]() {
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
});
// Only the outermost module in the traversal records (and later clears) the
// root module name used for diagnostics.
bool isRootModule = !rootModuleName.has_value();
if (isRootModule) {
rootModuleName = moduleTypeName;
}
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
const std::vector<c10::ClassAttribute> &classAttributes =
currentModule.type()->getAttributes();
assert(slots.size() == classAttributes.size() &&
"mismatch between object and type!");
for (int i = 0, e = slots.size(); i < e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i];
attributeNameStack.push_back(classAttribute.getName());
MlirValue slotValue = importIValue(slots[i]);
// TODO: Is it necessary to track whether an attribute is a "parameter"?
createMlirOperationAtEnd(
nnModuleBody, "torch.slot", loc, slotValue,
toMlirNamedAttribute(
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))));
attributeNameStack.pop_back();
}
if (isRootModule) {
rootModuleName = c10::nullopt;
}
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
return mlirOperationGetResult(nnModule, 0);
}
MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
auto it = valueMap.find(ivalue);
if (it != valueMap.end()) {
return it->second;
}
// Reject potentially aliased tensors.
if (ivalue.isTensor()) {
c10::StorageImpl *storageImpl =
ivalue.toTensor().storage().unsafeGetStorageImpl();
if (!seenStorageImpls.insert(storageImpl).second) {
std::stringstream msg;
msg << "Unhandled tensor that shares storage with another tensor.";
if (rootModuleName) {
msg << "\nFound at path '<root>."
<< c10::QualifiedName(attributeNameStack).qualifiedName()
<< "' from root object '" << *rootModuleName << "'";
}
throw std::invalid_argument(msg.str());
}
}
MlirValue value = rawImportIValue(ivalue);
valueMap[ivalue] = value;
return value;
}
MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
if (ivalue.isBool()) {
MlirType type = npcompBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bool_constant", loc, type,
toMlirNamedAttribute("value",
mlirBoolAttrGet(context, ivalue.toBool())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isDouble()) {
MlirType type = mlirF64TypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.numeric_constant", loc, type,
toMlirNamedAttribute(
"value", mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isInt()) {
MlirType type = mlirIntegerTypeGet(context, 64);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.numeric_constant", loc, type,
toMlirNamedAttribute("value",
mlirIntegerAttrGet(type, ivalue.toInt())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isList()) {
c10::List<c10::IValue> list = ivalue.toList();
std::vector<MlirValue> elems;
for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
npcompListTypeGet(context), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTuple()) {
auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> elems;
for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
npcompTupleTypeGet(context), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
return importTensor(ivalue);
}
if (ivalue.isModule()) {
return importModule(ivalue.toModule());
}
if (ivalue.isString()) {
MlirType type = npcompBytesTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bytes_constant", loc, type,
toMlirNamedAttribute(
"value",
mlirStringAttrGet(context,
toMlirStringRef(ivalue.toString()->string()))));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isNone()) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isCustomClass()) {
if (ivalue.type().get() ==
c10::getCustomClassType<c10::intrusive_ptr<LinearPackedParamsBase>>()
.get()) {
c10::intrusive_ptr<LinearPackedParamsBase> linearParams =
ivalue.toCustomClass<LinearPackedParamsBase>();
at::Tensor weight;
c10::optional<at::Tensor> bias;
std::tie(weight, bias) = linearParams->unpack();
MlirValue weightValue = importIValue(c10::IValue(weight));
c10::optional<MlirValue> biasValue = c10::nullopt;
if (bias.has_value()) {
biasValue = importIValue(c10::IValue(*bias));
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.linear_params.create", loc,
npcompLinearParamsTypeGet(context), weightValue, biasValue);
return mlirOperationGetResult(operation, 0);
}
}
std::stringstream msg;
msg << "Unsupported ivalue: " << ivalue;
throw std::invalid_argument(msg.str());
}
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
assert(ivalue.isTensor() && "expected a tensor!");
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
// Import the bulk tensor representation.
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
MlirOperation constant = createMlirOperationAtEnd(
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(constant, 0);
// Construct the complete tensor value. This is trivial for most tensors, but
// for quantized tensors (and probably sparse too, TBD) there is more for us
// to do.
MlirValue tensorValue;
if (tensor.is_quantized()) {
// Note that Torch models quantization in a type-erased way. So we don't
// make an effort here to do any special static modeling. If desired, later
// compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation.
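// Rough shape of what gets emitted for the per-tensor-affine case
// (illustrative placeholders, not exact printed syntax):
//   %repr = std.constant dense<...>                  // bulk data from above
//   %scale = basicpy.numeric_constant <q_scale>      // f64, via importIValue
//   %zp = basicpy.numeric_constant <q_zero_point>    // i64, via importIValue
//   %quantized = torch.per_tensor_affine.create %repr, %scale, %zp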
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type()), {nullptr});
if (tensor.qscheme() == c10::kPerTensorAffine) {
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
MlirOperation quantizedTensor = createMlirOperationAtEnd(
importBlock, "torch.per_tensor_affine.create", loc,
quantizedTensorType, tensorReprValue, qScale, zeroPoint);
tensorValue = mlirOperationGetResult(quantizedTensor, 0);
} else {
std::stringstream msg;
msg << "Unsupported quantization scheme '"
<< c10::toString(tensor.qscheme()) << "' for tensor: " << ivalue;
throw std::invalid_argument(msg.str());
}
} else {
tensorValue = tensorReprValue;
}
// Convert the tensor to ndarray to match Torch's default-mutable semantics.
MlirOperation ndarray = createMlirOperationAtEnd(
importBlock, "numpy.create_array_from_tensor", loc,
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
tensorValue);
return mlirOperationGetResult(ndarray, 0);
}
void IValueImporter::importMethod(torch::jit::Function *function,
MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation) {
// The function's name becomes the MLIR symbol table name of the imported func
// when we import the compilation unit.
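// For example (illustrative), a `forward` method on `__torch__.MyModule` has
// `function->qualname().qualifiedName()` == "__torch__.MyModule.forward"
// (used as the symbol reference here) while `function->name()` is just
// "forward" (used as the method name below).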
const std::string &symName = function->qualname().qualifiedName();
MlirAttribute functionSymbolRef =
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName));
c10::optional<MlirNamedAttribute> isPrivate;
if (!methodAnnotation.isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
}
createMlirOperationAtEnd(
classTypeBody, "torch.method", mlirLocationUnknownGet(context),
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
toMlirNamedAttribute("function", functionSymbolRef), isPrivate);
}
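// Imports `classType` as a `torch.class_type` op containing one `torch.attr`
// per attribute and one `torch.method` per method. Roughly (illustrative
// names, exact printed syntax may differ):
//
//   torch.class_type @__torch__.MyModule {
//     torch.attr "training" : !basicpy.BoolType
//     torch.method "forward", @__torch__.MyModule.forward
//   }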
void IValueImporter::importClassType(c10::ClassType *classType) {
if (!classTypes.insert(classType).second) {
return;
}
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
MlirOperation op = createMlirOperationAtEnd(
importBlock, "torch.class_type", loc, mlirRegionCreate(),
toMlirNamedAttribute(
"sym_name",
mlirStringAttrGet(
context, toMlirStringRef(classType->name()->qualifiedName()))));
MlirRegion region = mlirOperationGetRegion(op, 0);
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
ClassAnnotation &classAnnotation =
annotator.getOrCreateClassAnnotation(classType);
const auto &attributeAnnotations = classAnnotation.getAttributeAnnotations();
const auto &classAttributes = classType->getAttributes();
for (int i = 0, e = classAttributes.size(); i != e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i];
c10::optional<MlirNamedAttribute> isPrivate;
if (!attributeAnnotations[i].isExported) {
isPrivate = toMlirNamedAttribute("isPrivate", mlirUnitAttrGet(context));
}
createMlirOperationAtEnd(
classTypeBody, "torch.attr", loc,
toMlirNamedAttribute(
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type",
mlirTypeAttrGet(typeMapper.mapFromTorchType(
loc, classAttribute.getType()))),
isPrivate);
}
const auto &methodAnnotations = classAnnotation.getMethodAnnotations();
const auto &methods = classType->methods();
for (int i = 0, e = methods.size(); i != e; i++) {
importMethod(methods[i], classTypeBody, methodAnnotations[i]);
}
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
}
void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
if (compilationUnit == nullptr) {
compilationUnit = cu;
} else {
// All sorts of stuff is connected to the compilation unit, such as
// c10::ClassType's (owned by the compilation unit), c10::FunctionType
// (which holds a pointer to a torch::jit::Function in the compilation
// unit), load-bearing symbol table names of functions, etc.
//
// It doesn't seem to be defined how multiple compilation units semantically
// connect with each other, and it doesn't seem to happen either (though
// structurally at the C++ level nothing prevents it), so make it an error.
if (compilationUnit != cu) {
throw std::invalid_argument(
"found two compilation units while importing");
}
return;
}
for (torch::jit::Function *function : cu->get_functions()) {
// Useful for debugging errors in free functions that end up being
// unused. These can be missing when round-tripping through the on-disk
// format, even though they still cause import issues when importing
// through the larger Python session where they originate.
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
// std::cerr << *function->graph();
MethodAnnotation *annotation =
annotator.getMethodAnnotationForFunction(function);
MlirOperation func = importJitFunctionAsFuncOp(
context, function, [&](int argIndex) -> MlirAttribute {
if (!annotation || !annotation->argAnnotations.has_value()) {
return {nullptr};
}
auto &shape = annotation->argAnnotations.value()[argIndex].shape;
auto &dtype = annotation->argAnnotations.value()[argIndex].dtype;
// TODO: Handle unranked tensors and tensors with unknown dtype (but
// possibly known ranks/sizes).
if (!shape || !dtype) {
return {nullptr};
}
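// Illustrative example (hypothetical annotation): a shape of [1, 3, 224, 224]
// with dtype float32 yields an argument attribute dictionary of roughly
// {torch.type_bound = !numpy.ndarray<[1,3,224,224]:f32>} (exact type syntax
// may differ).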
auto typeBound = npcompNdArrayTypeGetRanked(
shape->size(), shape->data(),
TypeMapper(context).mapFromTorchScalarType(
mlirLocationUnknownGet(context), *dtype));
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
"torch.type_bound", mlirTypeAttrGet(typeBound));
return mlirDictionaryAttrGet(context, 1, &typeBoundAttr);
});
// For IValue importing, the logical linkage structure of the module
// is determined by the object graph.
//
// The functions' symbol names are thus irrelevant to the module's
// externally visible characteristics, so mark them all as private.
//
// These functions may be referenced by the object graph, which can make
// them reachable from the externally visible characteristics of the module,
// but they are not intrinsically externally visible.
mlirOperationSetAttributeByName(
func, toMlirStringRef("sym_visibility"),
mlirStringAttrGet(context, toMlirStringRef("private")));
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), func);
}
}
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirContext context, ClassAnnotator &annotator) {
// When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule())
// ivalue.toModule().dump(true, false, false);
IValueImporter importer(block, context, annotator);
importer.importIValue(ivalue);
}