mirror of https://github.com/llvm/torch-mlir
240 lines
8.3 KiB
C++
240 lines
8.3 KiB
C++
//===- node_importer.cpp --------------------------------------------------===//
|
|
//
|
|
// This file is licensed under a pytorch-style license
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "node_importer.h"
|
|
|
|
#include <unordered_map>
|
|
|
|
#include "mlir_utils.h"
|
|
#include "op_builder.h"
|
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
#include "mlir-c/Diagnostics.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace torch_mlir;
|
|
|
|
using Value = torch::jit::Value;
|
|
using Block = torch::jit::Block;
|
|
using Node = torch::jit::Node;
|
|
|
|
namespace {
|
|
class NodeImporter {
|
|
public:
|
|
NodeImporter(MlirContext context) : context(context) {}
|
|
|
|
void importNode(Node *node, MlirBlock appendToBlock);
|
|
MlirBlock importBlock(Block *jitBlock, const std::string &terminatorOpName);
|
|
|
|
private:
|
|
void importPrimNode(Node *node, MlirBlock appendToBlock);
|
|
void importKernelCall(Node *node, MlirBlock appendToBlock);
|
|
|
|
MlirBlock createBlockFor(Block *jitBlock);
|
|
void mapValue(Value *jitValue, MlirValue value);
|
|
void mapResults(Node *node, MlirOperation operation);
|
|
MlirValue lookupMappedValue(Value *jitValue);
|
|
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);
|
|
|
|
MlirContext context;
|
|
std::unordered_map<Value *, MlirValue> valueMap;
|
|
};
|
|
} // namespace
|
|
|
|
void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|
TypeMapper typeMapper(context);
|
|
MlirLocation loc = getMlirLocationFromNode(context, node);
|
|
auto kind = node->kind();
|
|
if (kind == c10::prim::Constant) {
|
|
auto output = node->output();
|
|
MlirOperation op;
|
|
OpBuilder builder(context);
|
|
if (output->type()->cast<c10::NoneType>()) {
|
|
op = builder.createNoneConstant(loc);
|
|
} else if (output->type()->cast<c10::BoolType>()) {
|
|
op = builder.createBoolConstant(
|
|
loc, static_cast<bool>(node->i(c10::attr::value)));
|
|
} else if (output->type()->cast<c10::StringType>()) {
|
|
// TODO: Are TorchScript strings bytes or str technically?
|
|
// For now, model it as bytes to avoid pledging more than we currently
|
|
// model (e.g. no unicode, etc.).
|
|
op = builder.createBytesConstant(loc, node->s(c10::attr::value));
|
|
} else {
|
|
MlirAttribute valueAttr = importAttribute(loc, node, c10::attr::value);
|
|
op = builder.createStdConstant(loc, valueAttr);
|
|
}
|
|
mlirBlockAppendOwnedOperation(appendToBlock, op);
|
|
mapResults(node, op);
|
|
return;
|
|
}
|
|
|
|
if (kind == c10::prim::GetAttr) {
|
|
MlirType resultType =
|
|
typeMapper.mapFromTorchType(loc, node->output()->type());
|
|
MlirOperation operation = createMlirOperationAtEnd(
|
|
appendToBlock, "torch.prim.GetAttr", loc, resultType,
|
|
lookupMappedValues(node->inputs()),
|
|
toMlirNamedAttribute("name",
|
|
importAttribute(loc, node, c10::attr::name)));
|
|
mapResults(node, operation);
|
|
return;
|
|
}
|
|
|
|
if (kind == c10::prim::SetAttr) {
|
|
createMlirOperationAtEnd(
|
|
appendToBlock, "torch.prim.SetAttr", loc,
|
|
lookupMappedValues(node->inputs()),
|
|
toMlirNamedAttribute("name",
|
|
importAttribute(loc, node, c10::attr::name)));
|
|
return;
|
|
}
|
|
|
|
if (kind == c10::prim::CallMethod) {
|
|
MlirType resultType =
|
|
typeMapper.mapFromTorchType(loc, node->output()->type());
|
|
MlirOperation operation = createMlirOperationAtEnd(
|
|
appendToBlock, "torch.prim.CallMethod", loc, resultType,
|
|
lookupMappedValues(node->inputs()),
|
|
toMlirNamedAttribute("name",
|
|
importAttribute(loc, node, c10::attr::name)));
|
|
mapResults(node, operation);
|
|
return;
|
|
}
|
|
|
|
if (kind == c10::prim::Print) {
|
|
MlirOperation operation =
|
|
createMlirOperationAtEnd(appendToBlock, "torch.prim.Print", loc,
|
|
lookupMappedValues(node->inputs()));
|
|
mapResults(node, operation);
|
|
return;
|
|
}
|
|
|
|
if (kind == c10::prim::If) {
|
|
// TorchScript will already have an explicit op to determine truthiness. So
|
|
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
|
|
MlirOperation pred = createMlirOperationAtEnd(
|
|
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
|
|
lookupMappedValue(node->input()));
|
|
MlirOperation operation = createMlirOperationAtEnd(
|
|
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
|
|
getMlirTypesFromValues(loc, node->outputs()), mlirRegionCreate(),
|
|
mlirRegionCreate());
|
|
mapResults(node, operation);
|
|
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
|
importBlock(node->blocks()[0], "scf.yield"));
|
|
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
|
|
importBlock(node->blocks()[1], "scf.yield"));
|
|
return;
|
|
}
|
|
|
|
// Unhandled.
|
|
{
|
|
std::stringstream msg;
|
|
msg << "unhandled prim operation: ";
|
|
node->print(msg, 0, nullptr);
|
|
mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
|
|
throw mlir_diagnostic_emitted();
|
|
}
|
|
}
|
|
|
|
void NodeImporter::importKernelCall(Node *node, MlirBlock appendToBlock) {
|
|
TypeMapper typeMapper(context);
|
|
MlirLocation loc = getMlirLocationFromNode(context, node);
|
|
KernelCallBuilder kcb(context, loc, node->kind().toQualString(),
|
|
node->schema());
|
|
for (MlirValue value : lookupMappedValues(node->inputs())) {
|
|
kcb.addOperand(value);
|
|
}
|
|
for (MlirType type : getMlirTypesFromValues(loc, node->outputs())) {
|
|
kcb.addResultType(type);
|
|
}
|
|
MlirOperation op = kcb.create();
|
|
mlirBlockAppendOwnedOperation(appendToBlock, op);
|
|
mapResults(node, op);
|
|
}
|
|
|
|
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|
if (node->kind().ns() == c10::namespaces::prim) {
|
|
importPrimNode(node, appendToBlock);
|
|
return;
|
|
}
|
|
if (node->maybeSchema()) {
|
|
importKernelCall(node, appendToBlock);
|
|
return;
|
|
}
|
|
|
|
{
|
|
std::stringstream msg;
|
|
msg << "unhandled: generic operation: ";
|
|
node->print(msg, 0, nullptr);
|
|
mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
|
|
throw mlir_diagnostic_emitted();
|
|
}
|
|
}
|
|
|
|
MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
|
const std::string &terminatorOpName) {
|
|
MlirBlock block = createBlockFor(jitBlock);
|
|
for (Node *node : jitBlock->nodes()) {
|
|
importNode(node, block);
|
|
}
|
|
Node *returnNode = jitBlock->return_node();
|
|
createMlirOperationAtEnd(block, terminatorOpName,
|
|
getMlirLocationFromNode(context, returnNode),
|
|
lookupMappedValues(returnNode->inputs()));
|
|
return block;
|
|
}
|
|
|
|
MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
|
|
Node *paramNode = jitBlock->param_node();
|
|
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
|
std::vector<MlirType> blockArgTypes =
|
|
getMlirTypesFromValues(loc, paramNode->outputs());
|
|
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
|
|
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
|
Value *jitValue = paramNode->outputs()[i];
|
|
MlirValue value = mlirBlockGetArgument(block, i);
|
|
mapValue(jitValue, value);
|
|
}
|
|
return block;
|
|
}
|
|
|
|
void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
|
|
auto it = valueMap.find(jitValue);
|
|
(void)it;
|
|
assert(it == valueMap.end() && "jitValue has already been mapped");
|
|
valueMap[jitValue] = value;
|
|
}
|
|
void NodeImporter::mapResults(Node *node, MlirOperation operation) {
|
|
assert(node->outputs().size() ==
|
|
(size_t)mlirOperationGetNumResults(operation));
|
|
for (int i = 0, e = node->outputs().size(); i < e; i++) {
|
|
mapValue(node->outputs()[i], mlirOperationGetResult(operation, i));
|
|
}
|
|
}
|
|
MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
|
|
auto it = valueMap.find(jitValue);
|
|
assert(it != valueMap.end() &&
|
|
"trying to get mapping for jitValue that is not mapped yet!");
|
|
return it->second;
|
|
}
|
|
std::vector<MlirValue>
|
|
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
|
std::vector<MlirValue> ret;
|
|
for (Value *value : values) {
|
|
ret.push_back(lookupMappedValue(value));
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
|
const std::string &terminatorOpName) {
|
|
NodeImporter importer(context);
|
|
return importer.importBlock(jitBlock, terminatorOpName);
|
|
}
|