2021-02-02 09:59:42 +08:00
|
|
|
//===- node_importer.cpp --------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "node_importer.h"
|
|
|
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
#include "mlir_utils.h"
|
|
|
|
#include "op_builder.h"
|
|
|
|
|
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
|
|
#include "mlir-c/Diagnostics.h"
|
2021-02-19 09:10:17 +08:00
|
|
|
#include "npcomp-c/Types.h"
|
2021-02-02 09:59:42 +08:00
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
using namespace torch_mlir;
|
|
|
|
|
|
|
|
using Value = torch::jit::Value;
|
|
|
|
using Block = torch::jit::Block;
|
|
|
|
using Node = torch::jit::Node;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class NodeImporter {
|
|
|
|
public:
|
|
|
|
NodeImporter(MlirContext context) : context(context) {}
|
|
|
|
|
|
|
|
void importNode(Node *node, MlirBlock appendToBlock);
|
2021-03-02 09:24:15 +08:00
|
|
|
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
|
2021-02-02 09:59:42 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
void importPrimNode(Node *node, MlirBlock appendToBlock);
|
|
|
|
void importKernelCall(Node *node, MlirBlock appendToBlock);
|
|
|
|
|
|
|
|
MlirBlock createBlockFor(Block *jitBlock);
|
|
|
|
void mapValue(Value *jitValue, MlirValue value);
|
|
|
|
void mapResults(Node *node, MlirOperation operation);
|
|
|
|
MlirValue lookupMappedValue(Value *jitValue);
|
|
|
|
std::vector<MlirValue> lookupMappedValues(c10::ArrayRef<Value *> values);
|
|
|
|
|
|
|
|
MlirContext context;
|
|
|
|
std::unordered_map<Value *, MlirValue> valueMap;
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Imports a node from the `prim::` namespace, appending the generated MLIR
/// operation(s) to `appendToBlock` and mapping the node's outputs to their
/// results.
///
/// Emits an MLIR error and throws `mlir_diagnostic_emitted` for prim kinds not
/// handled below.
void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
  TypeMapper typeMapper(context);
  MlirLocation loc = getMlirLocationFromNode(context, node);
  auto kind = node->kind();

  // Builds `opName` with the node's inputs as operands and result types
  // derived from the node's outputs, then maps the outputs to the results.
  auto createAndMapTrivialNode = [&](Node *node, const std::string &opName) {
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, opName, loc,
                                 getMlirTypesFromValues(loc, node->outputs()),
                                 lookupMappedValues(node->inputs()));
    mapResults(node, operation);
  };
  // Same as above, plus a single named attribute `attrName` = `attr`.
  auto createAndMapNodeWithAttribute = [&](Node *node,
                                           const std::string &opName,
                                           const std::string &attrName,
                                           MlirAttribute attr) {
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, opName, loc,
                                 getMlirTypesFromValues(loc, node->outputs()),
                                 lookupMappedValues(node->inputs()),
                                 toMlirNamedAttribute(attrName.c_str(), attr));
    mapResults(node, operation);
  };
  // Kinds with a direct 1:1 op mapping. There is no `default:`; any kind not
  // listed here falls through to the if-chain below.
  switch (kind) {
  case c10::prim::TupleIndex:
  case c10::prim::TupleUnpack:
  case c10::prim::ListUnpack:
  case c10::prim::dtype:
  case c10::prim::device:
  case c10::prim::unchecked_cast:
  case c10::prim::Uninitialized:
  case c10::prim::RaiseException:
  case c10::prim::Print:
  case c10::prim::NumToTensor: {
    // Op name is `torch.prim.` + the unqualified prim kind name.
    createAndMapTrivialNode(node,
                            "torch.prim." + std::string(kind.toUnqualString()));
    return;
  }
  case c10::prim::ListConstruct: {
    createAndMapTrivialNode(node, "basicpy.build_list");
    return;
  }
  case c10::prim::TupleConstruct: {
    createAndMapTrivialNode(node, "basicpy.build_tuple");
    return;
  }
  case c10::prim::GetAttr:
  case c10::prim::SetAttr: {
    // The node's `name` attribute is carried over as the op's `name`.
    createAndMapNodeWithAttribute(
        node, "torch.prim." + std::string(kind.toUnqualString()), "name",
        importAttribute(loc, node, c10::attr::name));
    return;
  }
  }

  // prim::Constant: dispatch on the output's TorchScript type. Unlike the
  // cases above, the op is created detached and appended explicitly.
  if (kind == c10::prim::Constant) {
    auto output = node->output();
    MlirOperation op;
    OpBuilder builder(context);
    if (output->type()->cast<c10::NoneType>()) {
      op = builder.createNoneConstant(loc);
    } else if (output->type()->cast<c10::BoolType>()) {
      // Bool constants are stored in the integer `value` attribute.
      op = builder.createBoolConstant(
          loc, static_cast<bool>(node->i(c10::attr::value)));
    } else if (output->type()->cast<c10::StringType>()) {
      // TODO: Are TorchScript strings bytes or str technically?
      // For now, model it as bytes to avoid pledging more than we currently
      // model (e.g. no unicode, etc.).
      op = builder.createBytesConstant(loc, node->s(c10::attr::value));
    } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
      // A function-typed constant: the FunctionType itself points at the
      // callee, so the symbol is taken from the function's qualified name
      // rather than from the node's `name` attribute.
      torch::jit::Function *function = functionType->function();
      const std::string &symName = function->qualname().qualifiedName();
      op = createMlirOperation(
          "std.constant", loc,
          getFunctionTypeFromSchema(context, function->getSchema()),
          toMlirNamedAttribute(
              "value",
              mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
    } else {
      // Anything else goes through the generic attribute importer.
      MlirAttribute valueAttr = importAttribute(loc, node, c10::attr::value);
      op = builder.createStdConstant(loc, valueAttr);
    }
    mlirBlockAppendOwnedOperation(appendToBlock, op);
    mapResults(node, op);
    return;
  }

  // prim::Loop -> `torch.prim.Loop` with one region. The first two inputs are
  // passed through as-is; the remaining (loop-carried) inputs are derefined to
  // the node's result types.
  if (kind == c10::prim::Loop) {
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.Loop", loc, resultTypes,
        lookupMappedValues(node->inputs().slice(0, 2)),
        derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
                       loc, appendToBlock),
        mlirRegionCreate());
    mapResults(node, operation);
    // The body terminator yields an npcomp bool followed by values matching
    // the loop's result types.
    std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
    terminatorOperandTypes.insert(terminatorOperandTypes.end(),
                                  resultTypes.begin(), resultTypes.end());
    auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                MlirBlock appendToBlock) {
      createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
                               derefineValues(yieldedValues,
                                              terminatorOperandTypes, loc,
                                              appendToBlock));
    };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    return;
  }

  // prim::If -> `scf.if` with a then-region and an else-region.
  if (kind == c10::prim::If) {
    // TorchScript will already have an explicit op to determine truthiness. So
    // all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
    MlirOperation pred = createMlirOperationAtEnd(
        appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
        lookupMappedValue(node->input()));
    std::vector<MlirType> resultTypes =
        getMlirTypesFromValues(loc, node->outputs());
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
        resultTypes, mlirRegionCreate(), mlirRegionCreate());
    mapResults(node, operation);
    // Both branches end in `scf.yield`, with yielded values derefined to the
    // `scf.if` result types.
    auto createTerminator =
        [&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
          createMlirOperationAtEnd(
              appendToBlock, "scf.yield", loc,
              derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
        };
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 0),
        importBlock(node->blocks()[0], createTerminator));
    mlirRegionAppendOwnedBlock(
        mlirOperationGetRegion(operation, 1),
        importBlock(node->blocks()[1], createTerminator));
    return;
  }

  // prim::CallMethod -> `torch.prim.CallMethod`. All inputs (including the
  // receiver at input 0) are derefined to the callee entry block's argument
  // types.
  // NOTE(review): neither the ClassType cast nor `findMethod` is null-checked
  // here; this assumes input 0 is class-typed and the method exists.
  if (kind == c10::prim::CallMethod) {
    auto classType = node->input(0)->type()->cast<c10::ClassType>();
    auto methodName = node->s(c10::attr::name);
    torch::jit::Function *function = classType->findMethod(methodName);
    torch::jit::Block *calleeEntryBlock = function->graph()->block();
    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
      return typeMapper.mapFromTorchType(loc, v->type());
    });
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.CallMethod", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
                       appendToBlock),
        toMlirNamedAttribute("name",
                             importAttribute(loc, node, c10::attr::name)));
    mapResults(node, operation);
    return;
  }

  // prim::CallFunction -> `std.call_indirect`. The callee is the mapped
  // function value at input 0; the remaining inputs are derefined to the
  // callee entry block's argument types. (Per the originating commit, the
  // indirect form is used because it canonicalizes easily to a direct call.)
  if (kind == c10::prim::CallFunction) {
    auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
    torch::jit::Block *calleeEntryBlock =
        functionType->function()->graph()->block();
    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
      return typeMapper.mapFromTorchType(loc, v->type());
    });
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "std.call_indirect", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        lookupMappedValue(node->input(0)),
        derefineValues(lookupMappedValues(node->inputs().slice(1)),
                       expectedTypes, loc, appendToBlock));
    mapResults(node, operation);
    return;
  }

  // Unhandled.
  {
    std::stringstream msg;
    msg << "unhandled prim operation: ";
    node->print(msg, 0, nullptr);
    mlirEmitError(getMlirLocationFromNode(context, node), msg.str().c_str());
    throw mlir_diagnostic_emitted();
  }
}
|
|
|
|
|
|
|
|
/// Imports a schema-bearing node as a kernel call.
///
/// Operands are the mapped values of the node's inputs, and result types are
/// derived from its outputs. The built op is appended to `appendToBlock`, and
/// the node's outputs are mapped to the op's results.
void NodeImporter::importKernelCall(Node *node, MlirBlock appendToBlock) {
  // (The previous unused `TypeMapper` local has been dropped.)
  MlirLocation loc = getMlirLocationFromNode(context, node);
  KernelCallBuilder kcb(context, loc, node->kind().toQualString(),
                        node->schema());
  for (MlirValue operand : lookupMappedValues(node->inputs()))
    kcb.addOperand(operand);
  for (MlirType resultType : getMlirTypesFromValues(loc, node->outputs()))
    kcb.addResultType(resultType);
  MlirOperation op = kcb.create();
  mlirBlockAppendOwnedOperation(appendToBlock, op);
  mapResults(node, op);
}
|
|
|
|
|
|
|
|
/// Dispatches a single TorchScript node to the matching importer.
///
/// `prim::` nodes go to importPrimNode. Any other node that carries a schema
/// goes to importKernelCall. Everything else emits an MLIR error and throws
/// `mlir_diagnostic_emitted`.
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
  const bool isPrimNode = node->kind().ns() == c10::namespaces::prim;
  if (isPrimNode)
    return importPrimNode(node, appendToBlock);

  if (node->maybeSchema())
    return importKernelCall(node, appendToBlock);

  std::stringstream diagnostic;
  diagnostic << "unhandled: generic operation: ";
  node->print(diagnostic, 0, nullptr);
  mlirEmitError(getMlirLocationFromNode(context, node),
                diagnostic.str().c_str());
  throw mlir_diagnostic_emitted();
}
|
|
|
|
|
|
|
|
/// Builds an MLIR block for `jitBlock`.
///
/// Creates a block with arguments for `jitBlock`'s parameters, imports every
/// node in order, and then lets `createTerminator` build the terminator from
/// the values returned by `jitBlock`'s return node.
MlirBlock NodeImporter::importBlock(Block *jitBlock,
                                    CreateTerminatorFn createTerminator) {
  MlirBlock mlirBlock = createBlockFor(jitBlock);
  for (Node *jitNode : jitBlock->nodes())
    importNode(jitNode, mlirBlock);

  std::vector<MlirValue> returnedValues =
      lookupMappedValues(jitBlock->return_node()->inputs());
  createTerminator(returnedValues, mlirBlock);
  return mlirBlock;
}
|
|
|
|
|
|
|
|
/// Creates a detached MLIR block whose argument types are derived from
/// `jitBlock`'s parameters, and maps each parameter to its block argument.
MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
  Node *paramNode = jitBlock->param_node();
  MlirLocation loc = getMlirLocationFromNode(context, paramNode);
  std::vector<MlirType> blockArgTypes =
      getMlirTypesFromValues(loc, paramNode->outputs());
  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
  // The MLIR C API counts and indexes block arguments with intptr_t; keep the
  // index in that type instead of narrowing it to int.
  for (intptr_t i = 0, e = mlirBlockGetNumArguments(block); i < e; ++i) {
    Value *jitValue = paramNode->outputs()[static_cast<size_t>(i)];
    MlirValue value = mlirBlockGetArgument(block, i);
    mapValue(jitValue, value);
  }
  return block;
}
|
|
|
|
|
|
|
|
/// Binds `jitValue` to `value` in the value map.
///
/// In debug builds this asserts that `jitValue` has not been bound before.
/// With NDEBUG, a repeated binding simply overwrites the previous one.
void NodeImporter::mapValue(Value *jitValue, MlirValue value) {
  assert(valueMap.count(jitValue) == 0 && "jitValue has already been mapped");
  valueMap[jitValue] = value;
}
|
|
|
|
/// Maps the i-th output of `node` to the i-th result of `operation`.
///
/// Asserts (in debug builds) that both have the same number of results.
void NodeImporter::mapResults(Node *node, MlirOperation operation) {
  c10::ArrayRef<Value *> outputs = node->outputs();
  assert(outputs.size() ==
             static_cast<size_t>(mlirOperationGetNumResults(operation)) &&
         "node and operation must have the same number of results");
  // Index with size_t and convert explicitly to the C API's intptr_t, rather
  // than narrowing the size to int.
  for (size_t i = 0, e = outputs.size(); i < e; ++i)
    mapValue(outputs[i],
             mlirOperationGetResult(operation, static_cast<intptr_t>(i)));
}
|
|
|
|
/// Returns the MLIR value previously bound to `jitValue` via mapValue.
///
/// Debug builds assert that the value is mapped. Release builds throw
/// `std::out_of_range` (via `unordered_map::at`) instead of dereferencing
/// `end()`.
MlirValue NodeImporter::lookupMappedValue(Value *jitValue) {
  auto it = valueMap.find(jitValue);
  assert(it != valueMap.end() &&
         "trying to get mapping for jitValue that is not mapped yet!");
  // The assert is compiled out under NDEBUG; never dereference end() then.
  return it != valueMap.end() ? it->second : valueMap.at(jitValue);
}
|
|
|
|
std::vector<MlirValue>
|
|
|
|
NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
|
|
|
std::vector<MlirValue> ret;
|
|
|
|
for (Value *value : values) {
|
|
|
|
ret.push_back(lookupMappedValue(value));
|
|
|
|
}
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Imports `jitBlock` into a new MLIR block in `context`, using
/// `createTerminator` to build the block's terminator.
///
/// Each call uses a fresh NodeImporter whose value map starts empty. Only
/// values defined by `jitBlock` itself (its parameters and node results) can
/// therefore be referenced.
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
                                  CreateTerminatorFn createTerminator) {
  return NodeImporter(context).importBlock(jitBlock, createTerminator);
}
|