//===- node_importer.cpp --------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "node_importer.h"

#include <sstream>
#include <unordered_map>

#include "mlir_utils.h"
#include "op_builder.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"

namespace py = pybind11;
using namespace torch_mlir;

using Value = torch::jit::Value;
using Block = torch::jit::Block;
using Node = torch::jit::Node;

namespace {
class NodeImporter {
public:
  NodeImporter(MlirContext context) : context(context) {}

  void importNode(Node *node, MlirBlock appendToBlock);
  MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);

private:
  void importPrimNode(Node *node, MlirBlock appendToBlock);
  void importKernelCall(Node *node, MlirBlock appendToBlock);

  MlirBlock createBlockFor(Block *jitBlock);
  void mapValue(Value *jitValue, MlirValue value);
  void mapResults(Node *node, MlirOperation operation);
  MlirValue lookupMappedValue(Value *jitValue);
  std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);

  MlirContext context;
  std::unordered_map<Value *, MlirValue> valueMap;
};
} // namespace

void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
  TypeMapper typeMapper(context);
  MlirLocation loc = getMlirLocationFromNode(context, node);
  auto kind = node->kind();

  if (kind == c10::prim::Constant) {
    auto output = node->output();
    MlirOperation op;
    OpBuilder builder(context);
    if (output->type()->cast<c10::NoneType>()) {
      op = builder.createNoneConstant(loc);
    } else if (output->type()->cast<c10::BoolType>()) {
      op = builder.createBoolConstant(
          loc, static_cast<bool>(node->i(c10::attr::value)));
    } else if (output->type()->cast<c10::StringType>()) {
      // TODO: Are TorchScript strings bytes or str technically?
      // For now, model it as bytes to avoid pledging more than we currently
      // model (e.g. no unicode, etc.).
      op = builder.createBytesConstant(loc, node->s(c10::attr::value));
    } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
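      // Free functions (functions that are not methods of any torch.nn.Module)
      // are referenced through a `prim::Constant` that materializes a function
      // object, which is then invoked with `prim::CallFunction`, e.g.:
      //
      //   %5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
      //   %6 : Tensor = prim::CallFunction(%5, %input.1, %4)
      //
      // The `name` attribute is not the function's full identity; the
      // c10::FunctionType holds a pointer directly to the torch::jit::Function
      // in the CompilationUnit. So we materialize a `std.constant` that refers
      // to the function by its qualified (linkage) name.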
      torch::jit::Function *function = functionType->function();
      const std::string &symName = function->qualname().qualifiedName();
      op = createMlirOperation(
          "std.constant", loc,
          getFunctionTypeFromSchema(context, function->getSchema()),
          toMlirNamedAttribute(
              "value",
              mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
    } else {
      MlirAttribute valueAttr = importAttribute(loc, node, c10::attr::value);
      op = builder.createStdConstant(loc, valueAttr);
    }
    mlirBlockAppendOwnedOperation(appendToBlock, op);
    mapResults(node, op);
    return;
  }

  if (kind == c10::prim::GetAttr) {
    MlirType resultType =
        typeMapper.mapFromTorchType(loc, node->output()->type());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.GetAttr", loc, resultType,
        lookupMappedValues(node->inputs()),
        toMlirNamedAttribute("name",
                             importAttribute(loc, node, c10::attr::name)));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::SetAttr) {
    createMlirOperationAtEnd(
        appendToBlock, "torch.prim.SetAttr", loc,
        lookupMappedValues(node->inputs()),
        toMlirNamedAttribute("name",
                             importAttribute(loc, node, c10::attr::name)));
    return;
  }

  if (kind == c10::prim::CallMethod) {
    MlirType resultType =
        typeMapper.mapFromTorchType(loc, node->output()->type());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.CallMethod", loc, resultType,
        lookupMappedValues(node->inputs()),
        toMlirNamedAttribute("name",
                             importAttribute(loc, node, c10::attr::name)));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::Print) {
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, "torch.prim.Print", loc,
                                 lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::TupleConstruct) {
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "basicpy.build_tuple", loc, npcompTupleTypeGet(context),
        lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::ListConstruct) {
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "basicpy.build_list", loc, npcompListTypeGet(context),
        lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }

  if (kind == c10::prim::Loop) {
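    // TorchScript allows a value of a subtype to flow where a supertype is
    // expected (e.g. an `int` where an `Optional[int]` is required), whereas
    // MLIR requires identical types. derefineValues bridges this mismatch with
    // `torch.derefine` casts.
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(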
        appendToBlock, "torch.prim.Loop", loc, resultTypes,
        lookupMappedValues(node->inputs().slice(0, 2)),
        derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
                       loc, appendToBlock),
        mlirRegionCreate());
    mapResults(node, operation);
    std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
    terminatorOperandTypes.insert(terminatorOperandTypes.end(),
                                  resultTypes.begin(), resultTypes.end());
    auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                MlirBlock appendToBlock) {
      createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
                               derefineValues(yieldedValues,
                                              terminatorOperandTypes, loc,
                                              appendToBlock));
    };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    return;
  }

  if (kind == c10::prim::If) {
    // TorchScript will already have an explicit op to determine truthiness. So
    // all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
    MlirOperation pred = createMlirOperationAtEnd(
        appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
        lookupMappedValue(node->input()));
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
        resultTypes, mlirRegionCreate(), mlirRegionCreate());
    mapResults(node, operation);
    auto createTerminator =
        [&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
          createMlirOperationAtEnd(
              appendToBlock, "scf.yield", loc,
              derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
        };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 1),
        importBlock(node->blocks()[1], createTerminator));
    return;
  }

  if (kind == c10::prim::NumToTensor) {
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, "torch.prim.NumToTensor", loc,
                                 getMlirTypesFromValues(loc, node->outputs()),
                                 lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }
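
  if (kind == c10::prim::CallFunction) {
    // Import as `std.call_indirect` on the function constant; this is easily
    // canonicalized to a direct call. Each argument is derefined to the type
    // that the callee's entry block expects.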
    auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
    torch::jit::Block *calleeEntryBlock =
        functionType->function()->graph()->block();
    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
      return typeMapper.mapFromTorchType(loc, v->type());
    });
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "std.call_indirect", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        lookupMappedValue(node->input(0)),
        derefineValues(lookupMappedValues(node->inputs().slice(1)),
                       expectedTypes, loc, appendToBlock));
mapResults(node, operation);
return;
}
if (kind == c10::prim::RaiseException) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.RaiseException", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::Uninitialized) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.Uninitialized", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::unchecked_cast) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.unchecked_cast", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::TupleUnpack) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.TupleUnpack", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::TupleIndex) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.TupleIndex", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::ListUnpack) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.ListUnpack", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::dtype) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.dtype", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
// Unhandled.
{
std::stringstream msg;
msg << "unhandled prim operation: ";
node->print(msg, 0, nullptr);
mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
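// Imports a node that has a registered operator schema (e.g. an `aten::` op) as
// a kernel call built by KernelCallBuilder. Operands are the already-mapped
// input values; result types are derived from the node's outputs.
void NodeImporter::importKernelCall(Node *node, MlirBlock appendToBlock) {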
TypeMapper typeMapper(context);
MlirLocation loc = getMlirLocationFromNode(context, node);
KernelCallBuilder kcb(context, loc, node->kind().toQualString(),
node->schema());
for (MlirValue value : lookupMappedValues(node->inputs())) {
kcb.addOperand(value);
}
for (MlirType type : getMlirTypesFromValues(loc, node->outputs())) {
kcb.addResultType(type);
}
MlirOperation op = kcb.create();
mlirBlockAppendOwnedOperation(appendToBlock, op);
mapResults(node, op);
}
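// Dispatches a single TorchScript node: `prim::` nodes get dedicated handling,
// nodes with a schema become kernel calls, and anything else is an error.
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {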
if (node->kind().ns() == c10::namespaces::prim) {
importPrimNode(node, appendToBlock);
return;
}
if (node->maybeSchema()) {
importKernelCall(node, appendToBlock);
return;
}
{
std::stringstream msg;
msg << "unhandled: generic operation: ";
node->print(msg, 0, nullptr);
mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
MlirBlock NodeImporter::importBlock(Block *jitBlock,
CreateTerminatorFn createTerminator) {
MlirBlock block = createBlockFor(jitBlock);
for (Node *node : jitBlock->nodes()) {
importNode(node, block);
}
Node *returnNode = jitBlock->return_node();
createTerminator(lookupMappedValues(returnNode->inputs()), block);
return block;
}
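// Creates an MLIR block whose arguments mirror the outputs of the JIT block's
// param node, and maps each JIT value to the corresponding block argument.
MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {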
Node *paramNode = jitBlock->param_node();
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
std::vector<MlirType> blockArgTypes =
getMlirTypesFromValues(loc, paramNode->outputs());
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
Value *jitValue = paramNode->outputs()[i];
MlirValue value = mlirBlockGetArgument(block, i);
mapValue(jitValue, value);
}
return block;
}
void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
auto it = valueMap.find(jitValue);
(void)it;
assert(it == valueMap.end() && "jitValue has already been mapped");
valueMap[jitValue] = value;
}
void NodeImporter::mapResults(Node *node, MlirOperation operation) {
assert(node->outputs().size() ==
(size_t)mlirOperationGetNumResults(operation));
for (int i = 0, e = node->outputs().size(); i < e; i++) {
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
}
}
MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
auto it = valueMap.find(jitValue);
assert(it != valueMap.end() &&
"trying to get mapping for jitValue that is not mapped yet!");
return it->second;
}
std::vector<MlirValue>
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
std::vector<MlirValue> ret;
for (Value *value : values) {
ret.push_back(lookupMappedValue(value));
}
return ret;
}
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
CreateTerminatorFn createTerminator) {
NodeImporter importer(context);
return importer.importBlock(jitBlock, createTerminator);
}