//===- graph_importer.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
namespace py = pybind11;
using namespace torch_mlir;
//------------------------------------------------------------------------------
// GraphImporter::NodeScope implementation
//------------------------------------------------------------------------------
// A scope of Graph Value * to corresponding MlirValue. Scopes nest
// region-wise. Note that in PyTorch, the thing called 'Block' is analogous
// to a capturing MLIR region.
class GraphImporter::NodeScope {
public:
NodeScope() = default;
NodeScope(NodeScope *prev) : prev(prev) {}
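/// Binds a torch Value to an MlirValue within this scope. Binding the same
/// torch Value twice is an error (asserted).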
void bindValue(torch::jit::Value *torchValue, MlirValue value);
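/// Looks up the MlirValue for a torch Value, walking enclosing scopes.
/// Returns a null MlirValue if not found.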
MlirValue findValue(torch::jit::Value *torchValue);
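/// Like findValue, but emits a diagnostic at `loc` and throws
/// mlir_diagnostic_emitted if the torch Value is unmapped.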
MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue);
private:
llvm::DenseMap<torch::jit::Value *, MlirValue> valueMap;
NodeScope *prev = nullptr;
};
void GraphImporter::NodeScope::bindValue(torch::jit::Value *torchValue,
MlirValue value) {
assert(valueMap.count(torchValue) == 0 && "duplicate torch Value bind");
valueMap[torchValue] = value;
}
MlirValue GraphImporter::NodeScope::findValue(torch::jit::Value *torchValue) {
auto foundIt = valueMap.find(torchValue);
if (foundIt == valueMap.end()) {
if (prev)
return prev->findValue(torchValue);
else
return {nullptr};
}
return foundIt->second;
}
MlirValue
GraphImporter::NodeScope::findRequiredValue(MlirLocation loc,
torch::jit::Value *torchValue) {
MlirValue value = findValue(torchValue);
if (mlirValueIsNull(value)) {
std::stringstream msg;
msg << "internal error: unmapped torch value: %" << torchValue->debugName();
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
return value;
}
//------------------------------------------------------------------------------
// GraphImporter::NodeImporter implementation
//------------------------------------------------------------------------------
/// Helper class to import a torch::jit::Node into an MLIR function.
/// This class primarily exists to eliminate the need for large lists of
/// carried arguments related to doing the import.
class GraphImporter::NodeImporter {
public:
NodeImporter(torch::jit::Node *node, GraphImporter &parent,
FuncBuilder *funcBuilder, MlirBlock block, MlirOperation ip,
NodeScope *scope);
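/// Imports the node, either via the special-cased prim handling or as a
/// generic kernel call derived from the node's function schema.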
void importNode();
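/// Imports the graph's return node as a std.return of its inputs.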
void importReturnOp();
private:
MlirContext context() { return parent.context(); }
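/// Imports a node in the prim namespace (currently just prim::Constant).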
void importPrimNode();
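/// Converts the node's `value` attribute to an MlirAttribute.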
MlirAttribute importValueAttribute();
torch::jit::Node *node;
GraphImporter &parent;
FuncBuilder *funcBuilder;
MlirBlock block;
MlirOperation ip;
NodeScope *scope;
MlirLocation loc;
};
GraphImporter::NodeImporter::NodeImporter(torch::jit::Node *node,
GraphImporter &parent,
FuncBuilder *funcBuilder,
MlirBlock block, MlirOperation ip,
NodeScope *scope)
: node(node), parent(parent), funcBuilder(funcBuilder), block(block),
ip(ip), scope(scope) {
loc = parent.extractCallstackLoc(node);
}
void GraphImporter::NodeImporter::importNode() {
// Prim namespace handled specially.
auto kind = node->kind();
if (kind.ns() == c10::namespaces::prim) {
importPrimNode();
return;
}
// Generic import.
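// For example, a node printed as `%3 : Tensor = aten::mm(%1, %2)` becomes a
// kernel call op named "aten::mm" with two operands and one result whose
// type is mapped from the Tensor type. (Illustrative; the concrete op form
// is whatever KernelCallBuilder emits.)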
auto funcSchema = node->maybeSchema();
if (funcSchema) {
KernelCallBuilder kcb(context(), loc, kind.toQualString(), *funcSchema);
for (auto *input : node->inputs()) {
kcb.addOperand(scope->findRequiredValue(loc, input));
}
for (auto *output : node->outputs()) {
MlirType type =
parent.type_mapper().mapFromTorchType(loc, output->type());
if (mlirTypeIsNull(type)) {
throw mlir_diagnostic_emitted();
}
kcb.addResultType(type);
}
MlirOperation op = kcb.create();
mlirBlockInsertOwnedOperationBefore(block, ip, op);
// Map results.
for (auto it : llvm::enumerate(node->outputs())) {
scope->bindValue(it.value(), mlirOperationGetResult(op, it.index()));
}
return;
}
// No soup for you: the node has no schema, so there is no generic way to
// import it. Not exactly sure when this can happen.
{
std::stringstream msg;
msg << "unhandled: generic operation " << kind.toDisplayString();
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
void GraphImporter::NodeImporter::importReturnOp() {
OperationStateHolder s("std.return", loc);
llvm::SmallVector<MlirValue, 4> returnsValues;
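// The return node's inputs are the values the function returns.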
for (auto *input : node->inputs()) {
returnsValues.push_back(scope->findRequiredValue(loc, input));
}
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
mlirBlockInsertOwnedOperationBefore(block, ip, s.createOperation());
}
void GraphImporter::NodeImporter::importPrimNode() {
auto kind = node->kind();
if (kind == c10::prim::Constant) {
auto output = node->output();
MlirAttribute valueAttr = importValueAttribute();
MlirValue constValue = funcBuilder->getGeneralConstant(loc, valueAttr);
scope->bindValue(output, constValue);
return;
}
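// For example, `%c : int = prim::Constant[value=1]()` takes the branch
// above and binds %c to an i64 constant (see importValueAttribute).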
// Unhandled.
{
std::stringstream msg;
msg << "unhandled: prim operation " << kind.toDisplayString();
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
MlirAttribute GraphImporter::NodeImporter::importValueAttribute() {
using torch::jit::AttributeKind;
auto s = c10::attr::value;
auto kind = node->kindOf(s);
switch (kind) {
case AttributeKind::i:
// TODO: This should be a signed int once we have a constant op that can
// do that.
return mlirIntegerAttrGet(mlirIntegerTypeGet(context(), 64), node->i(s));
case AttributeKind::f:
return mlirFloatAttrDoubleGet(context(), mlirF64TypeGet(context()),
node->f(s));
default: {
std::stringstream msg;
msg << "unhandled: value attribute kind " << toString(kind);
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
}
//------------------------------------------------------------------------------
// GraphImporter implementation
//------------------------------------------------------------------------------
GraphImporter::GraphImporter(std::shared_ptr<torch::jit::Graph> graph,
MlirMappingOptions mappingOptions)
: graph(std::move(graph)), mappingOptions(std::move(mappingOptions)) {}
std::shared_ptr<GraphImporter> GraphImporter::forPythonJitFunc(
torch::jit::Function *function,
GraphImporter::MlirMappingOptions mappingOptions) {
// Disallow an attempt to compile a native function.
if (!function->isGraphFunction()) {
throw std::invalid_argument(
"Expected a torch.jit.ScriptFunction with a graph");
}
auto graph = function->graph();
if (!mappingOptions.genericFuncName) {
mappingOptions.genericFuncName = function->name() + "$generic";
}
if (!mappingOptions.funcName) {
mappingOptions.funcName = function->name();
}
return std::make_shared<GraphImporter>(graph, std::move(mappingOptions));
}
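// Expected call sequence (a sketch; the actual driver lives elsewhere in
// the frontend):
//   auto importer = GraphImporter::forPythonJitFunc(function, options);
//   importer->initialize();
//   importer->importGenericFunc();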
void GraphImporter::initialize() {
defaultLoc = mlirLocationUnknownGet(context());
// The graph itself has no callstack associated with it, so try to grab a
// location from the first node that has one as a better-than-nothing
// default.
// TODO: This doesn't actually seem to be working. Investigate when more
// examples are built out.
for (auto *node : graph->nodes()) {
MlirLocation nodeLoc = extractCallstackLoc(node, /*useDefault=*/false);
if (nodeLoc.ptr) {
defaultLoc = nodeLoc;
break;
}
}
// Map inputs.
MlirLocation inputLoc = extractCallstackLoc(graph->param_node());
for (const auto &input : graph->inputs()) {
MlirType t = type_mapper().mapFromTorchType(inputLoc, input->type());
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("could not convert function input type");
genericFuncArgTypes.push_back(t);
}
// Map outputs.
MlirLocation outputLoc = extractCallstackLoc(graph->return_node());
for (const auto &output : graph->outputs()) {
MlirType t = type_mapper().mapFromTorchType(outputLoc, output->type());
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("could not convert function output type");
genericFuncReturnTypes.push_back(t);
}
}
void GraphImporter::importGenericFunc() {
auto funcBuilder = FuncBuilder::createFunction(
mappingOptions.inserter, defaultLoc, *mappingOptions.genericFuncName,
genericFuncArgTypes);
funcBuilder->rewriteFuncReturnTypes(genericFuncReturnTypes);
MlirBlock entryBlock = funcBuilder->getEntryBlock();
// Bind inputs.
NodeScope scope;
for (const auto &it : llvm::enumerate(graph->inputs())) {
MlirValue value = mlirBlockGetArgument(entryBlock, it.index());
scope.bindValue(it.value(), value);
}
// Walk body nodes.
for (auto *node : graph->nodes()) {
NodeImporter importer{
node, *this, funcBuilder.get(), entryBlock, /*ip=*/{nullptr}, &scope};
importer.importNode();
}
// Map the output node to a return.
auto *outputNode = graph->return_node();
NodeImporter returnImporter{outputNode, *this,
funcBuilder.get(), entryBlock,
/*ip=*/{nullptr}, &scope};
returnImporter.importReturnOp();
}
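/// Derives a file:line:col location from the node's source range, falling
/// back to the graph's default location (or a null location) when absent.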
MlirLocation GraphImporter::extractCallstackLoc(torch::jit::Node *node,
bool useDefault) {
auto flc = node->sourceRange().file_line_col();
if (flc) {
const std::string &file = std::get<0>(*flc);
int line = std::get<1>(*flc);
int col = std::get<2>(*flc);
return mlirLocationFileLineColGet(context(), file.c_str(), line, col);
}
return useDefault ? defaultLoc : MlirLocation{nullptr};
}