torch-mlir/frontends/pytorch/csrc/builder/graph_importer.cpp

//===- graph_importer.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
Properly import the entire torch::jit::CompilationUnit

This primarily unlocks proper handling of free functions (that is, functions that are not methods of any torch.nn.Module).

Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff

The `torch::jit::CompilationUnit` is basically a backing store or "context" holding all the possible functions in the program. The previous code did not access this data structure explicitly, because it only imported the `torch::jit::Function`s that it saw attached to methods.

Subtly, any time a TorchScript module calls into a free function, the free function gets incorporated into the torch::jit::CompilationUnit. However, it doesn't show up anywhere when dumping the module, except in this curious pattern:

```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```

That is, calls are indirect calls, and they are accessed via a `prim::Constant` that materializes a function object. Even stranger, the `name` attribute here doesn't tell the full story: it isn't what determines the callee. It turns out that the c10::FunctionType itself holds a pointer directly to the `torch::jit::Function` in the compilation unit. So there is actually no indirection in prim::CallFunction, because any two values of the same FunctionType call the same function! For example, when converting the IR to bytecode, the "name" is ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).

We do import `prim::CallFunction` as a `std.call_indirect`, though, because that is the simpler way to do it (it gets canonicalized to a direct call easily).
2021-02-27 08:20:35 +08:00
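The commit note makes two claims. First, a function-typed value carries its callee in its type. Second, an indirect call through a constant function reference can be rewritten as a direct call. Both can be illustrated with a small self-contained model.

Everything below is hypothetical. `Function`, `FunctionType`, `FunctionValue`, `Op`, `callFunction`, and `canonicalizeIndirectCalls` are simplified stand-ins written for this sketch, not the real `torch::jit`, `c10`, or MLIR classes. The toy IR only mimics the shape of the `std.call_indirect` to direct-call rewrite.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for torch::jit::Function: a name plus a body.
struct Function {
  std::string qualifiedName;
  std::function<int(int)> body;
};

// Hypothetical stand-in for c10::FunctionType: the type itself points at the
// callee, so every value of this type refers to the same Function.
struct FunctionType {
  Function *function;
};

// Models the value produced by `prim::Constant[name=...]()` with a function
// type. The `name` attribute is carried along but never consulted on a call.
struct FunctionValue {
  std::string nameAttr;
  const FunctionType *type;
};

// Models `prim::CallFunction`: dispatch goes through the type's pointer only.
int callFunction(const FunctionValue &callee, int arg) {
  return callee.type->function->body(arg);
}

// Toy IR op for the canonicalization sketch (hypothetical, not MLIR).
struct Op {
  std::string kind;    // "constant", "call_indirect", or "call"
  std::string symbol;  // function symbol for "constant" and "call"
  int calleeOperand;   // index of the op defining the callee, or -1
};

// Rewrites every indirect call whose callee is defined by a constant
// function reference into a direct call to that symbol.
void canonicalizeIndirectCalls(std::vector<Op> &ops) {
  for (Op &op : ops) {
    if (op.kind != "call_indirect" || op.calleeOperand < 0 ||
        static_cast<std::size_t>(op.calleeOperand) >= ops.size())
      continue;
    const Op &def = ops[op.calleeOperand];
    if (def.kind == "constant") {
      op.kind = "call";
      op.symbol = def.symbol;
      op.calleeOperand = -1;
    }
  }
}
```

Consider two values that share one `FunctionType` but carry different `name` attributes. `callFunction` reaches the same `Function` for both, which mirrors the observation that the `name` is ignored. Running `canonicalizeIndirectCalls` on a constant followed by a `call_indirect` that uses it leaves a single direct call to the constant's symbol.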
#include "torch_to_mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation torch_mlir::importGraphAsFuncOp(MlirContext context,
                                              torch::jit::Graph *graph,
                                              const std::string &name) {
  // Useful for debugging:
  // graph->dump();
  MlirLocation loc = mlirLocationUnknownGet(context);
  // The function's signature is derived from the graph's top-level block.
  MlirAttribute typeAttr =
      mlirTypeAttrGet(getFunctionTypeFromBlock(context, graph->block()));
  MlirAttribute symNameAttr = mlirStringAttrGet(context, toMlirStringRef(name));
  // Create the `func` op with an empty body region plus its `type` and
  // `sym_name` attributes.
  MlirOperation func = createMlirOperation(
      "func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", typeAttr),
      toMlirNamedAttribute("sym_name", symNameAttr));
  // Import the graph's block, terminated by `std.return`, and move it into
  // the function body (the region takes ownership of the block).
  MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
  MlirBlock block = importBlock(context, graph->block(), "std.return");
  mlirRegionAppendOwnedBlock(bodyRegion, block);
  return func;
}
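`importGraphAsFuncOp` builds a `func` operation with `type` and `sym_name` attributes and one region. The region holds a single imported block ending in `std.return`. That nesting can be mirrored in a toy IR to make the structure explicit.

This is a minimal sketch under stated assumptions:
- `Operation`, `Region`, `Block`, `importBlockSketch`, and `importGraphAsFuncSketch` are hypothetical.
- Attribute values are plain strings rather than MLIR attributes.
- Nothing here calls the real MLIR C API.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical miniature IR, loosely mirroring the MLIR objects used by
// importGraphAsFuncOp (operation, region, block, named attributes).
struct Block {
  std::vector<std::string> opNames;  // operation names in program order
};

struct Region {
  std::vector<Block> blocks;
};

struct Operation {
  std::string name;
  std::map<std::string, std::string> attributes;  // name -> printed value
  std::vector<Region> regions;
};

// Hypothetical stand-in for importBlock(): copies the graph's node kinds and
// ends the block with the caller-chosen terminator (e.g. "std.return").
Block importBlockSketch(const std::vector<std::string> &nodeKinds,
                        const std::string &terminatorName) {
  Block block;
  block.opNames = nodeKinds;
  block.opNames.push_back(terminatorName);
  return block;
}

// Same shape as importGraphAsFuncOp: create a "func" op carrying "type" and
// "sym_name" attributes and one region, then append the imported body block.
Operation importGraphAsFuncSketch(const std::vector<std::string> &nodeKinds,
                                  const std::string &functionType,
                                  const std::string &name) {
  Operation func;
  func.name = "func";
  func.attributes["type"] = functionType;
  func.attributes["sym_name"] = name;
  func.regions.emplace_back();  // analogous to passing mlirRegionCreate()
  func.regions[0].blocks.push_back(importBlockSketch(nodeKinds, "std.return"));
  return func;
}
```

In this sketch the function body is always region 0, as `mlirOperationGetRegion(func, 0)` assumes in the real code. The terminator name is a parameter rather than being hard-coded inside the block importer.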