torch-mlir/frontends/pytorch/csrc/builder/function_importer.cpp


Properly model "derefinement".

In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type.

We introduce a new op, `torch.derefine`, that models this impedance mismatch. It allows casting a value from one type to a type of which it is a subtype.

Recommended review order:
- TorchOps.td for the new torch.derefine (and updated docs for `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h, because only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are actually needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions, which are now prefixed with `__torch__.`, but that is more correct anyway: that is their linkage name in the `torch::jit::CompilationUnit`, and the previous `mb.import_function` was actually buggy in the case of functions calling each other, as it would reference their unqualified name.

With this change, we can import `resnet18` from `torchvision` :)

IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-02 09:24:15 +08:00
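
For illustration, a minimal TorchScript sketch of a schema that forces derefinement at the return boundary, in the spirit of the `function-derefine.py` test referenced above (the function name `optional_return` is hypothetical; only `torch.jit.script` and `typing.Optional` from stock PyTorch are used):

    import typing

    import torch

    @torch.jit.script
    def optional_return(i: int) -> typing.Optional[int]:
        # The schema declares an `Optional[int]` result, but the body returns
        # a plain `int`; when imported, a `torch.derefine` bridges the value
        # to the declared optional result type at the return.
        return i
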
//===- function_importer.cpp ----------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "function_importer.h"

#include <unordered_map>

#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"

namespace py = pybind11;
using namespace torch_mlir;

MlirOperation
torch_mlir::importJitFunctionAsFuncOp(MlirContext context,
                                      torch::jit::Function *function) {
  // Useful for debugging:
  // function->graph()->dump();
  MlirLocation loc = mlirLocationUnknownGet(context);
  MlirType functionType =
      getFunctionTypeFromSchema(context, function->getSchema());
  // Use the function's qualified name from the compilation unit.
  // This is a stable linkage name that matches Python module lookup
  // conventions (see compilation unit import in IValueImporter for more
  // details on qualified names).
  MlirAttribute symNameAttr = mlirStringAttrGet(
      context, toMlirStringRef(function->qualname().qualifiedName()));
  MlirOperation func = createMlirOperation(
      "func", loc, mlirRegionCreate(),
      toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
      toMlirNamedAttribute("sym_name", symNameAttr));
  MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);

  std::vector<MlirType> resultTypes;
  for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e;
       i++) {
    resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
  }
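  // The terminator's operand types must match the function's declared result
  // types exactly. Where a yielded value has a proper subtype (e.g. `int`
  // where the schema declares `Optional[int]`), `derefineValues` inserts a
  // `torch.derefine` to bridge the difference.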
  auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                              MlirBlock appendToBlock) {
    createMlirOperationAtEnd(
        appendToBlock, "std.return", loc,
        derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
  };
  MlirBlock block =
      importBlock(context, function->graph()->block(), createTerminator);
  mlirRegionAppendOwnedBlock(bodyRegion, block);
  return func;
}