Add TorchScript graph importer.

* Does not handle all features yet but should conservatively fail on unsupported things.
* Location tracking is still somewhat mismatched between what TorchScript and MLIR do. Likely need a better heuristic for tracking locations from defs for nodes that do not carry location.
* Sets the ground-work for a specialized/generic split but only implements the generic side.
* Had some evidence that this requires a recent bump of PT nightly (within the last month) to pick up pybind11 2.6, which includes some cross-module symbol fixes (vs the previously sync'd version). No source changes, but older versions fail to cast function types at runtime.
pull/124/head
Stella Laurenzo 2020-11-20 17:03:23 -08:00
parent 2021d3609e
commit 78a3c90758
17 changed files with 712 additions and 91 deletions

View File

@ -13,6 +13,13 @@ include_directories(BEFORE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
# TODO: Fix implicit ordering. If PyTorch was built against an external
# pybind11, then it will not be in the above search path and must be
# resolved here, in the hope that it is the same one we were configured
# with (which it should be if installed via pip). This is really fragile,
# though, causing cast failures at runtime if we get it wrong. Come up with
# a way to strengthen this.
${pybind11_INCLUDE_DIR}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
@ -28,7 +35,11 @@ target_link_libraries(NPCOMPTorchMLIRExt
# NPCOMP shared library.
NPCOMP
)
add_dependencies(NPCOMPTorchMLIRExt
# Uses of the torch_mlir extension also require the npcomp extension to
# be built.
NPCOMPNativePyExt
)
set_target_properties(NPCOMPTorchMLIRExt PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/python
OUTPUT_NAME _torch_mlir

View File

@ -5,12 +5,20 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
# TODO: Fix implicit ordering. If PyTorch was built against an external
# pybind11, then it will not be in the above search path and must be
# resolved here, in the hope that it is the same one we were configured
# with (which it should be if installed via pip). This is really fragile,
# though, causing cast failures at runtime if we get it wrong. Come up with
# a way to strengthen this.
${pybind11_INCLUDE_DIR}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(npcomp_torch_builder_bindings
acap_dispatch.cpp
debug.cpp
func_builder.cpp
graph_importer.cpp
module_builder.cpp
python_bindings.cpp
)

View File

@ -42,63 +42,17 @@ static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
static c10::DispatchKey kAcapGradDispatchKey =
c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;
AcapController::KernelCallBuilder::KernelCallBuilder(
AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder(
AcapController &parent, MlirContext context, MlirLocation loc,
const c10::OperatorHandle &opHandle,
llvm::Optional<std::string> overrideKernelName)
: parent(parent), context(context), loc(loc), opHandle(opHandle),
state("torch.kernel_call", loc) {
(void)this->context; // Preserve for future.
const std::string &kernelName =
overrideKernelName ? *overrideKernelName : opHandle.operator_name().name;
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
"kernelName",
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
addSchemaAttrs();
}
: KernelCallBuilder(context, loc,
overrideKernelName ? *overrideKernelName
: opHandle.operator_name().name,
opHandle.schema()),
parent(parent), opHandle(opHandle) {}
void AcapController::KernelCallBuilder::addSchemaAttrs() {
// Map the op schema to the kernel_call attributes:
// sigArgTypes
// sigRetTypes
// sigIsVararg
// sigIsVarret
// sigIsMutable
const c10::FunctionSchema &schema = opHandle.schema();
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
attrs.push_back(mlirNamedAttributeGet(
"sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable())));
attrs.push_back(mlirNamedAttributeGet(
"sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg())));
attrs.push_back(mlirNamedAttributeGet(
"sigIsVarret", mlirBoolAttrGet(context, schema.is_varret())));
// Arg types.
llvm::SmallVector<MlirAttribute, 4> args;
for (auto &arg : schema.arguments()) {
const std::string &typeStr = arg.type()->str();
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
}
attrs.push_back(mlirNamedAttributeGet(
"sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data())));
// Return types.
llvm::SmallVector<MlirAttribute, 4> returns;
for (auto &ret : schema.returns()) {
const std::string &typeStr = ret.type()->str();
returns.push_back(
mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
}
attrs.push_back(mlirNamedAttributeGet(
"sigRetTypes",
mlirArrayAttrGet(context, returns.size(), returns.data())));
// Add attrs to op.
mlirOperationStateAddAttributes(state, attrs.size(), attrs.data());
}
void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
void AcapController::TracedKernelCallBuilder::addOperand(const IValue &value) {
MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
if (mlirValueIsNull(mlirValue)) {
std::stringstream out;
@ -107,10 +61,10 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
<< value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
mlirOperationStateAddOperands(state, 1, &mlirValue);
KernelCallBuilder::addOperand(mlirValue);
}
void AcapController::KernelCallBuilder::addResult(const IValue &value) {
void AcapController::TracedKernelCallBuilder::addResult(const IValue &value) {
MlirType resultType = parent.mapIValueToMlirType(loc, value);
if (mlirTypeIsNull(resultType)) {
std::stringstream out;
@ -122,12 +76,11 @@ void AcapController::KernelCallBuilder::addResult(const IValue &value) {
if (value.isTensor()) {
resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
}
mlirOperationStateAddResults(state, 1, &resultType);
KernelCallBuilder::addResultType(resultType);
}
MlirOperation AcapController::KernelCallBuilder::create() {
// Create operation.
MlirOperation op = state.createOperation();
MlirOperation AcapController::TracedKernelCallBuilder::create() {
MlirOperation op = KernelCallBuilder::create();
parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
// Map result tensors.
@ -264,7 +217,7 @@ at::Tensor AcapController::convolutionKernel(
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
std::string kernelName{"aten::convolution"};
KernelCallBuilder callBuilder{*current, context, loc, *opHandle};
TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle};
callBuilder.addOperand(IValue(input));
callBuilder.addOperand(IValue(weight));
@ -344,8 +297,8 @@ AcapController::mklConvolutionBackward(
""};
auto emitOpHandle = dispatcher.findOp(emitOpName);
assert(emitOpHandle && "could not find convolution_backward_overrideable op");
KernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle,
kernelName};
TracedKernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle,
kernelName};
callBuilder.addOperand(IValue(grad_output));
callBuilder.addOperand(IValue(input));
@ -398,7 +351,7 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
KernelCallBuilder callBuilder{*current, context, loc, *opHandle};
TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle};
callBuilder.addOperand(IValue(self));
callBuilder.addOperand(IValue(src));
@ -462,7 +415,7 @@ void AcapController::fallbackKernelImpl(
MlirContext context = funcBuilder->getContext();
MlirLocation loc = getCurrentLocation();
auto kernelName = schema.name();
KernelCallBuilder callBuilder{*this, context, loc, opHandle};
TracedKernelCallBuilder callBuilder{*this, context, loc, opHandle};
// Map arguments to operands.
// This must be accumulated into the OperationState prior to re-dispatch
@ -553,7 +506,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
const IValue &ival) {
if (ival.isScalar()) {
return typeMapper.mapScalarType(ival.toScalar().type());
return typeMapper.mapFromTorchScalarType(ival.toScalar().type());
}
if (ival.isTensor()) {
return typeMapper.forwardTensorToType(ival.toTensor());
@ -601,7 +554,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
// Basicpy bool type.
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
} else {
elementType = typeMapper.mapScalarType(tensor.scalar_type());
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
}
llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
tensor.sizes().end());

View File

@ -80,9 +80,9 @@ public:
private:
/// Builds a kernel call step by step.
class KernelCallBuilder {
class TracedKernelCallBuilder : private KernelCallBuilder {
public:
KernelCallBuilder(
TracedKernelCallBuilder(
AcapController &parent, MlirContext context, MlirLocation loc,
const c10::OperatorHandle &opHandle,
llvm::Optional<std::string> overrideKernelName = llvm::None);
@ -91,12 +91,8 @@ private:
MlirOperation create();
private:
void addSchemaAttrs();
AcapController &parent;
MlirContext context;
MlirLocation loc;
const c10::OperatorHandle &opHandle;
OperationStateHolder state;
int resultCount = 0;
llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
};

View File

@ -7,6 +7,7 @@
#include "func_builder.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
#include "npcomp-c/Types.h"
@ -22,7 +23,90 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
return s.createOperation();
}
MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) {
KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc,
llvm::StringRef kernelName,
const c10::FunctionSchema &schema)
: context(context), loc(loc), state("torch.kernel_call", loc),
kernelName(kernelName), schema(schema) {
(void)this->context; // Preserve for future.
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
"kernelName",
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
addSchemaAttrs();
}
void KernelCallBuilder::addSchemaAttrs() {
// Map the op schema to the kernel_call attributes:
// sigArgTypes
// sigRetTypes
// sigIsVararg
// sigIsVarret
// sigIsMutable
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
attrs.push_back(mlirNamedAttributeGet(
"sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable())));
attrs.push_back(mlirNamedAttributeGet(
"sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg())));
attrs.push_back(mlirNamedAttributeGet(
"sigIsVarret", mlirBoolAttrGet(context, schema.is_varret())));
// Arg types.
llvm::SmallVector<MlirAttribute, 4> args;
for (auto &arg : schema.arguments()) {
const std::string &typeStr = arg.type()->str();
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
}
attrs.push_back(mlirNamedAttributeGet(
"sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data())));
// Return types.
llvm::SmallVector<MlirAttribute, 4> returns;
for (auto &ret : schema.returns()) {
const std::string &typeStr = ret.type()->str();
returns.push_back(
mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
}
attrs.push_back(mlirNamedAttributeGet(
"sigRetTypes",
mlirArrayAttrGet(context, returns.size(), returns.data())));
// Add attrs to op.
mlirOperationStateAddAttributes(state, attrs.size(), attrs.data());
}
void KernelCallBuilder::addOperand(MlirValue operand) {
mlirOperationStateAddOperands(state, 1, &operand);
}
void KernelCallBuilder::addResultType(MlirType resultType) {
mlirOperationStateAddResults(state, 1, &resultType);
}
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
throw std::invalid_argument(message.str());
}
return type;
}
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
mlirEmitError(loc, message.str().c_str());
}
return type;
}
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
using c10::ScalarType;
switch (scalarType) {
case ScalarType::Byte:
@ -49,10 +133,49 @@ MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) {
return mlirBF16TypeGet(context);
case ScalarType::Half:
return mlirF16TypeGet(context);
default: {
return {nullptr};
}
}
}
/// Maps a full Torch type (currently only TensorType) to an MlirType.
/// Returns a null type and emits a diagnostic at `loc` on failure.
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
                                      const c10::TypePtr &torchType) {
  using c10::TypeKind;
  auto kind = torchType->kind();
  switch (kind) {
  case TypeKind::TensorType: {
    auto tensorType = torchType->cast<c10::TensorType>();
    // Element type: fall back to the "any dtype" type when the tensor has no
    // statically known scalar type.
    MlirType elementType;
    if (tensorType->scalarType()) {
      elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
      if (mlirTypeIsNull(elementType))
        return {nullptr};
    } else {
      elementType = npcompAnyDtypeTypeGet(context);
    }
    // Sizes.
    auto &sizes = tensorType->symbolic_sizes();
    if (!sizes.rank()) {
      // Unranked.
      return npcompNdArrayTypeGetUnranked(elementType);
    }
    // Ranked with possibly dynamic dims (-1 marks a dynamic dimension).
    llvm::SmallVector<int64_t, 4> dims;
    dims.resize(*sizes.rank());
    for (size_t i = 0; i < dims.size(); ++i) {
      auto shapeSymbol = sizes[i];
      dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
    }
    return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
  }
  default: {
    std::stringstream message;
    message << "unable to map Torch type " << torchType << " to MLIR type";
    mlirEmitError(loc, message.str().c_str());
    return {nullptr};
  }
  }
}
@ -64,7 +187,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
return npcompNoneTypeGet(context);
}
MlirType elementType = mapScalarType(tensor.scalar_type());
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
// TODO: Decide when it is necessary to take strides into account. Right now,
// just erase them and let the compiler decide.
@ -73,9 +196,10 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
}
std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(MlirContext context, MlirLocation location,
llvm::StringRef name,
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
MlirLocation location, llvm::StringRef name,
llvm::SmallVectorImpl<MlirType> &inputTypes) {
auto context = mlirLocationGetContext(location);
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
// (this is fragile and reveals details that are not guaranteed).
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
@ -102,6 +226,7 @@ FuncBuilder::createFunction(MlirContext context, MlirLocation location,
MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);
inserter(funcOp);
return std::unique_ptr<FuncBuilder>(new FuncBuilder(
context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
}

View File

@ -13,6 +13,7 @@
#include "llvm/ADT/StringRef.h"
#include <ATen/Tensor.h>
#include <ATen/core/function_schema.h>
namespace torch_mlir {
@ -51,21 +52,30 @@ public:
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Throws std::invalid_argument on failure.
MlirType mapScalarType(c10::ScalarType scalarType);
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
/// Gets a corresponding MlirType for the forward component of a tensor.
/// Throws std::invalid_argument on failure.
MlirType forwardTensorToType(at::Tensor tensor);
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Returns a null type on failure and emits a diagnostic.
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
private:
/// Maps from a scalar type and returns null if no match (no other error
/// reporting).
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
MlirContext context;
};
/// Wraps an MlirBlock under construction, primarily tracking the terminator
/// and supporting manipulation of it. The terminator may be null if it has
/// not yet been constructed, although, for entry blocks, we always construct
/// the function with an appropriate return terminator (which can be changed
/// later).
/// not yet been constructed.
class BlockBuilder {
public:
BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn)
@ -86,15 +96,39 @@ private:
bool isReturn;
};
/// Builds a kernel call step by step.
class KernelCallBuilder {
public:
KernelCallBuilder(MlirContext context, MlirLocation loc,
llvm::StringRef kernelName,
const c10::FunctionSchema &schema);
void addOperand(MlirValue operand);
void addResultType(MlirType resultType);
MlirOperation create();
protected:
MlirContext context;
MlirLocation loc;
private:
void addSchemaAttrs();
OperationStateHolder state;
llvm::StringRef kernelName;
const c10::FunctionSchema &schema;
};
/// Wraps a 'func' MlirOperation and provides facilities for constructing
/// IR from some stream of Torch operations.
class FuncBuilder {
public:
/// Callback for inserting a function.
using Inserter = std::function<void(MlirOperation funcOp)>;
/// Creates a new func op with the given characteristics. The created
/// operation is not attached. The caller must either destroy it or add it
/// to a parent.
static std::unique_ptr<FuncBuilder>
createFunction(MlirContext context, MlirLocation location,
createFunction(Inserter &inserter, MlirLocation location,
llvm::StringRef name,
llvm::SmallVectorImpl<MlirType> &inputTypes);

View File

@ -0,0 +1,303 @@
//===- graph_importer.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
namespace py = pybind11;
using namespace torch_mlir;
//------------------------------------------------------------------------------
// GraphImporter::NodeScope implementation
//------------------------------------------------------------------------------
// A scope of Graph Value * to corresponding MlirValue. Scopes nest
// region-wise. Note that in PyTorch, the thing called 'Block' is analogous
// to a capturing MLIR region.
class GraphImporter::NodeScope {
public:
NodeScope() = default;
NodeScope(NodeScope *prev) : prev(prev) {}
void bindValue(torch::jit::Value *torchValue, MlirValue value);
MlirValue findValue(torch::jit::Value *torchValue);
MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue);
private:
llvm::DenseMap<torch::jit::Value *, MlirValue> valueMap;
NodeScope *prev = nullptr;
};
void GraphImporter::NodeScope::bindValue(torch::jit::Value *torchValue,
MlirValue value) {
assert(valueMap.count(torchValue) == 0 && "duplicate torch Value bind");
valueMap[torchValue] = value;
}
// Resolves `torchValue` by walking outward along the scope chain.
// Returns a null MlirValue when no scope binds it.
MlirValue GraphImporter::NodeScope::findValue(torch::jit::Value *torchValue) {
  for (NodeScope *cursor = this; cursor; cursor = cursor->prev) {
    auto foundIt = cursor->valueMap.find(torchValue);
    if (foundIt != cursor->valueMap.end())
      return foundIt->second;
  }
  return {nullptr};
}
// Like findValue(), but an unmapped value is an internal importer error:
// emit a diagnostic at `loc` and throw.
MlirValue
GraphImporter::NodeScope::findRequiredValue(MlirLocation loc,
                                            torch::jit::Value *torchValue) {
  MlirValue mapped = findValue(torchValue);
  if (!mlirValueIsNull(mapped))
    return mapped;
  std::stringstream msg;
  msg << "internal error: unmapped torch value: %" << torchValue->debugName();
  mlirEmitError(loc, msg.str().c_str());
  throw mlir_diagnostic_emitted();
}
//------------------------------------------------------------------------------
// GraphImporter::NodeImporter implementation
//------------------------------------------------------------------------------
/// Helper class to import a torch::jit::Node into an MLIR function.
/// This class primarily exists to eliminate the need for large lists of
/// carried arguments related to doing the import.
class GraphImporter::NodeImporter {
public:
NodeImporter(torch::jit::Node *node, GraphImporter &parent,
FuncBuilder *funcBuilder, MlirBlock block, MlirOperation ip,
NodeScope *scope);
void importNode();
void importReturnOp();
private:
MlirContext context() { return parent.context(); }
void importPrimNode();
MlirAttribute importValueAttribute();
torch::jit::Node *node;
GraphImporter &parent;
FuncBuilder *funcBuilder;
MlirBlock block;
MlirOperation ip;
NodeScope *scope;
MlirLocation loc;
};
GraphImporter::NodeImporter::NodeImporter(torch::jit::Node *node,
GraphImporter &parent,
FuncBuilder *funcBuilder,
MlirBlock block, MlirOperation ip,
NodeScope *scope)
: node(node), parent(parent), funcBuilder(funcBuilder), block(block),
ip(ip), scope(scope) {
loc = parent.extractCallstackLoc(node);
}
// Imports one body node. Prim-namespace nodes get dedicated handling; any
// other node must carry a function schema, which drives a generic
// torch.kernel_call import. Unsupported nodes emit a diagnostic and throw.
void GraphImporter::NodeImporter::importNode() {
  auto kind = node->kind();
  if (kind.ns() == c10::namespaces::prim) {
    importPrimNode();
    return;
  }

  auto funcSchema = node->maybeSchema();
  if (!funcSchema) {
    // Not exactly sure when a non-prim node can lack a schema; fail
    // conservatively.
    std::stringstream msg;
    msg << "unhandled: generic operation " << kind.toDisplayString();
    mlirEmitError(loc, msg.str().c_str());
    throw mlir_diagnostic_emitted();
  }

  // Generic import via kernel_call.
  KernelCallBuilder kcb(context(), loc, kind.toQualString(), *funcSchema);
  for (auto *input : node->inputs())
    kcb.addOperand(scope->findRequiredValue(loc, input));
  for (auto *output : node->outputs()) {
    MlirType type = parent.type_mapper().mapFromTorchType(loc, output->type());
    if (mlirTypeIsNull(type))
      throw mlir_diagnostic_emitted();
    kcb.addResultType(type);
  }
  MlirOperation op = kcb.create();
  mlirBlockInsertOwnedOperationBefore(block, ip, op);
  // Bind each result so downstream nodes can reference it.
  for (auto it : llvm::enumerate(node->outputs()))
    scope->bindValue(it.value(), mlirOperationGetResult(op, it.index()));
}
void GraphImporter::NodeImporter::importReturnOp() {
OperationStateHolder s("std.return", loc);
llvm::SmallVector<MlirValue, 4> returnsValues;
for (auto *input : node->inputs()) {
returnsValues.push_back(scope->findRequiredValue(loc, input));
}
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
mlirBlockInsertOwnedOperationBefore(block, ip, s.createOperation());
}
// Handles prim-namespace nodes. Only prim::Constant is supported so far;
// everything else emits a diagnostic and throws.
void GraphImporter::NodeImporter::importPrimNode() {
  auto kind = node->kind();
  if (kind != c10::prim::Constant) {
    std::stringstream msg;
    msg << "unhandled: prim operation " << kind.toDisplayString();
    mlirEmitError(loc, msg.str().c_str());
    throw mlir_diagnostic_emitted();
  }
  // prim::Constant: materialize the value attribute as a general constant
  // and bind it to the node's single output.
  MlirAttribute valueAttr = importValueAttribute();
  MlirValue constValue = funcBuilder->getGeneralConstant(loc, valueAttr);
  scope->bindValue(node->output(), constValue);
}
/// Converts the prim::Constant node's `value` attribute to an MlirAttribute.
/// Only integer (i) and float (f) attribute kinds are handled; anything else
/// emits a diagnostic and throws mlir_diagnostic_emitted.
/// (Fix: removed unreachable `break` statements after `return`.)
MlirAttribute GraphImporter::NodeImporter::importValueAttribute() {
  using torch::jit::AttributeKind;
  auto s = c10::attr::value;
  auto kind = node->kindOf(s);
  switch (kind) {
  case AttributeKind::i:
    // TODO: This should be a signed int once we have a constant op that can
    // do that.
    return mlirIntegerAttrGet(mlirIntegerTypeGet(context(), 64), node->i(s));
  case AttributeKind::f:
    return mlirFloatAttrDoubleGet(context(), mlirF64TypeGet(context()),
                                  node->f(s));
  default: {
    std::stringstream msg;
    msg << "unhandled: value attribute kind " << toString(kind);
    mlirEmitError(loc, msg.str().c_str());
    throw mlir_diagnostic_emitted();
  }
  }
}
//------------------------------------------------------------------------------
// GraphImporter implementation
//------------------------------------------------------------------------------
GraphImporter::GraphImporter(std::shared_ptr<torch::jit::Graph> graph,
MlirMappingOptions mappingOptions)
: graph(std::move(graph)), mappingOptions(std::move(mappingOptions)) {}
/// Creates a GraphImporter for a scripted/traced python function.
/// Defaults the generic function name to "<name>$generic" and the
/// specialized entry-point name to the plain function name when unset.
/// Throws std::invalid_argument for native (non-graph) functions.
std::shared_ptr<GraphImporter> GraphImporter::forPythonJitFunc(
    torch::jit::Function *function,
    GraphImporter::MlirMappingOptions mappingOptions) {
  // Disallow an attempt to compile a native function.
  if (!function->isGraphFunction()) {
    throw std::invalid_argument(
        "Expected a torch.jit.ScriptFunction with a graph");
  }
  auto graph = function->graph();
  if (!mappingOptions.genericFuncName) {
    mappingOptions.genericFuncName = function->name() + "$generic";
  }
  if (!mappingOptions.funcName) {
    // Fix: the specialized entry point uses the plain function name; only the
    // generic variant carries the "$generic" suffix (previously both did).
    mappingOptions.funcName = function->name();
  }
  return std::make_shared<GraphImporter>(graph, std::move(mappingOptions));
}
void GraphImporter::initialize() {
defaultLoc = mlirLocationUnknownGet(context());
// There is not a callstack associated with the graph so, try to grab
// a location from the first node that has one as a better than nothing
// thing.
// TODO: This doesn't actually seem to be working. Investigate when more
// examples are built out.
for (auto *node : graph->nodes()) {
MlirLocation nodeLoc = extractCallstackLoc(node, /*useDefault=*/false);
if (nodeLoc.ptr) {
defaultLoc = nodeLoc;
break;
}
}
// Map inputs.
MlirLocation inputLoc = extractCallstackLoc(graph->param_node());
for (const auto &input : graph->inputs()) {
MlirType t = type_mapper().mapFromTorchType(inputLoc, input->type());
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("could not convert function input type");
genericFuncArgTypes.push_back(t);
}
// Map outputs.
MlirLocation outputLoc = extractCallstackLoc(graph->return_node());
for (const auto &output : graph->outputs()) {
MlirType t = type_mapper().mapFromTorchType(outputLoc, output->type());
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("could not convert function output type");
genericFuncReturnTypes.push_back(t);
}
}
void GraphImporter::importGenericFunc() {
auto funcBuilder = FuncBuilder::createFunction(
mappingOptions.inserter, defaultLoc, *mappingOptions.genericFuncName,
genericFuncArgTypes);
funcBuilder->rewriteFuncReturnTypes(genericFuncReturnTypes);
MlirBlock entryBlock = funcBuilder->getEntryBlock();
// Bind inputs.
NodeScope scope;
for (const auto &it : llvm::enumerate(graph->inputs())) {
MlirValue value = mlirBlockGetArgument(entryBlock, it.index());
scope.bindValue(it.value(), value);
}
// Walk body nodes.
for (auto *node : graph->nodes()) {
NodeImporter importer{
node, *this, funcBuilder.get(), entryBlock, /*ip=*/{nullptr}, &scope};
importer.importNode();
}
// Map the output node to a return.
auto *outputNode = graph->return_node();
NodeImporter returnImporter{outputNode, *this,
funcBuilder.get(), entryBlock,
/*ip=*/{nullptr}, &scope};
returnImporter.importReturnOp();
}
// Derives a file:line:col location from the node's source range when
// available; otherwise returns defaultLoc (or a null location when
// useDefault is false, letting callers detect the miss).
MlirLocation GraphImporter::extractCallstackLoc(torch::jit::Node *node,
                                                bool useDefault) {
  auto flc = node->sourceRange().file_line_col();
  if (!flc)
    return useDefault ? defaultLoc : MlirLocation{nullptr};
  const std::string &file = std::get<0>(*flc);
  int line = std::get<1>(*flc);
  int col = std::get<2>(*flc);
  return mlirLocationFileLineColGet(context(), file.c_str(), line, col);
}

View File

@ -0,0 +1,91 @@
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Graph instances (and structures
/// surrounding them such as modules and methods).
///
/// In torch terminology, a Graph is a function. Later in the compiler, we may
/// specialize multiple versions of it.
///
/// Since graph functions typically have enough annotations for the most
/// generic form of every type (i.e. Tensor, List, etc), and since we often
/// want to multi-version specific specializations, we take the approach of
/// generating a '$generic' suffixed function at that level and then generate
/// the actual named function with using a 'numpy.generic_call' op to invoke
/// the generic function with metadata controlling how it is legal to
/// specialize. This leaves the process of inlining and expanding the
/// specializations to compiler passes.
class GraphImporter : public std::enable_shared_from_this<GraphImporter> {
public:
/// Options for mapping Graph concepts to MLIR. In addition to things like
/// names and type mappings, this includes various policy options such as
/// when to import globals as constants vs shared arrays, etc.
struct MlirMappingOptions {
MlirContext context;
llvm::Optional<std::string> genericFuncName;
llvm::Optional<std::string> funcName;
TypeMapper &typeMapper;
FuncBuilder::Inserter &inserter;
};
/// Construct an importer.
GraphImporter(std::shared_ptr<torch::jit::Graph> graph,
MlirMappingOptions mappingOptions);
/// Helper to create a graph importer from a traced/scripted python function.
/// If the funcName of the mapping options is not set, it is set from the
/// function name. It is the responsibility of the caller to ensure that the
/// funcObj and associated graph outlives this instance.
static std::shared_ptr<GraphImporter>
forPythonJitFunc(torch::jit::Function *function,
MlirMappingOptions mappingOptions);
/// Initialize for import. This is separate from the constructor purely for
/// ergonomics and must be called post-construction. Initialization activities
/// that throw go here.
void initialize();
/// Imports the generic function into the module.
void importGenericFunc();
private:
class NodeScope;
class NodeImporter;
MlirContext context() { return mappingOptions.context; }
TypeMapper &type_mapper() { return mappingOptions.typeMapper; }
MlirLocation extractCallstackLoc(torch::jit::Node *node,
bool useDefault = true);
std::shared_ptr<torch::jit::Graph> graph;
MlirMappingOptions mappingOptions;
/// Default function location, to be used when a more specific is not
/// available.
MlirLocation defaultLoc;
/// Argument and return types for the generic func.
llvm::SmallVector<MlirType, 4> genericFuncArgTypes;
llvm::SmallVector<MlirType, 4> genericFuncReturnTypes;
};
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H

View File

@ -7,6 +7,8 @@
#include "module_builder.h"
#include "graph_importer.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
@ -78,11 +80,9 @@ ModuleBuilder::startCaptureFunction(std::string &name,
}
// TODO: Extract a traceback and use in place of unknownLoc.
auto inserter = createInserter();
auto funcBuilder =
FuncBuilder::createFunction(context, unknownLoc, name, inputTypes);
mlirBlockInsertOwnedOperationBefore(getBodyBlock(), terminator,
funcBuilder->getFuncOp());
FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
// Map block arguments.
MlirBlock entryBlock = funcBuilder->getEntryBlock();
assert(mlirBlockGetNumArguments(entryBlock) ==
@ -95,6 +95,28 @@ ModuleBuilder::startCaptureFunction(std::string &name,
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
}
void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
  // Assemble the mapping options. Field order follows MlirMappingOptions:
  // context, genericFuncName, funcName, typeMapper, inserter. Both name
  // fields are left unset so they default to names derived from the
  // function itself.
  GraphImporter::MlirMappingOptions options{
      context,
      /*genericFuncName=*/llvm::None,
      /*funcName=*/llvm::None,
      typeMapper,
      createInserter(),
  };
  // Import the TorchScript graph into the module as a generic func.
  auto importer = GraphImporter::forPythonJitFunc(function.function_,
                                                  std::move(options));
  importer->initialize();
  importer->importGenericFunc();
}
FuncBuilder::Inserter ModuleBuilder::createInserter() {
  // Snapshot the insertion point now and capture by value so the returned
  // closure does not depend on this ModuleBuilder staying alive.
  MlirBlock bodyBlock = getBodyBlock();
  MlirOperation insertionAnchor = this->terminator;
  return [bodyBlock, insertionAnchor](MlirOperation op) {
    // Transfers ownership of `op` into the body block, just ahead of the
    // terminator.
    mlirBlockInsertOwnedOperationBefore(bodyBlock, insertionAnchor, op);
  };
}
MlirBlock ModuleBuilder::getBodyBlock() {
MlirOperation moduleOp = mlirModuleGetOperation(module);
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
@ -106,5 +128,6 @@ void ModuleBuilder::bind(py::module &m) {
.def_property_readonly("context", &ModuleBuilder::getContextObj)
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("capture_function", &ModuleBuilder::startCaptureFunction,
py::keep_alive<0, 1>());
py::keep_alive<0, 1>())
.def("import_function", &ModuleBuilder::importFunction);
}

View File

@ -16,6 +16,8 @@
#include "llvm/ADT/SmallVector.h"
#include <ATen/Tensor.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
@ -32,11 +34,16 @@ public:
pybind11::object getModuleObj() { return moduleObj; }
// Starts a device-capture based function.
// TODO: Add inputs.
std::shared_ptr<AcapController>
startCaptureFunction(std::string &name, std::vector<at::Tensor> args);
// Imports a traced function. Note that the python type
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
// Just a bit of naming cruft.
void importFunction(torch::jit::StrongFunctionPtr function);
private:
FuncBuilder::Inserter createInserter();
MlirBlock getBodyBlock();
// Capture references to the python-owned context and module. Ownership

View File

@ -140,4 +140,3 @@ void torch_mlir::InitBuilderBindings(py::module &m) {
ModuleBuilder::bind(m);
}

View File

@ -14,4 +14,15 @@
#include <torch/csrc/utils/pybind.h>
namespace torch_mlir {
/// Thrown on failure when details are in MLIR emitted diagnostics.
/// The message is intentionally terse; the actionable information is
/// expected to be in diagnostics emitted prior to the throw.
class mlir_diagnostic_emitted : public std::runtime_error {
public:
  /// Constructs with a custom summary message.
  /// `explicit` to avoid accidental implicit conversion from `const char *`.
  explicit mlir_diagnostic_emitted(const char *what)
      : std::runtime_error(what) {}
  /// Default-constructs with a generic message pointing at the diagnostics.
  mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
};
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H

View File

@ -2,6 +2,8 @@
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir

View File

@ -0,0 +1,25 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
@torch.jit.script
def add3(t0, t1, t2):
    # Chain two elementwise adds; scripting lowers this to two aten::add
    # kernel calls (same graph as the single-expression form).
    partial = t0 + t1
    return partial + t2
# Construct a fresh ModuleBuilder and import the scripted function into it.
mb = torch_mlir.ModuleBuilder()
mb.import_function(add3)
# Verify without debug info.
# CHECK-LABEL: func @add3$generic
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[C1:.*]] = constant 1 : i64
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
# Print without debug info; the FileCheck patterns above match this output.
mb.module.operation.print()
print()

View File

@ -0,0 +1,26 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | FileCheck %s
@torch.jit.script
def add3(t0, t1, t2):
    # Two separate statements so each add is its own node carrying its own
    # source location for debug-info propagation.
    partial_sum = t0 + t1
    total = partial_sum + t2
    return total
# Construct a ModuleBuilder and import the scripted function into it.
mb = torch_mlir.ModuleBuilder()
mb.import_function(add3)
# Verify again with debug info present. Just checking that it makes it in there.
# CHECK-LABEL: func @add3$generic
# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
# Print with debug info enabled so source locations appear in the output.
mb.module.operation.print(enable_debug_info=True)
print()

View File

@ -64,6 +64,9 @@ MlirType npcompListTypeGet(MlirContext context);
/** Checks whether the given type is an NdArray type. */
int npcompTypeIsANdArray(MlirType t);
/** Gets a numpy.NdArray type that is unranked. */
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType);
/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
* unknown. */
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,

View File

@ -70,6 +70,10 @@ int npcompTypeIsANdArray(MlirType t) {
return unwrap(t).isa<Numpy::NdArrayType>();
}
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) {
  // Builds a numpy.NdArray type that carries only an element type (no rank
  // or shape information).
  auto eltType = unwrap(elementType);
  return wrap(Numpy::NdArrayType::get(eltType));
}
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType) {
llvm::ArrayRef<int64_t> shapeArray(shape, rank);