mirror of https://github.com/llvm/torch-mlir
Add TorchScript graph importer.
* Does not handle all features yet but should conservatively fail on unsupported things. * Location tracking is still somewhat mismatched between what TorchScript and MLIR do. Likely need a better heuristic for tracking locations from defs for nodes that do not carry location. * Sets the ground-work for a specialized/generic split but only implements the generic side. * Had some evidence that this requires a recent bump of PT nightly (within the last month) to pick up pybind11 2.6, which includes some cross-module symbol fixes (vs the previously sync'd version). No source changes, but older versions fail to cast function types at runtime.pull/124/head
parent
2021d3609e
commit
78a3c90758
|
@ -13,6 +13,13 @@ include_directories(BEFORE
|
|||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
# TODO: Fix implicit ordering. If PyTorch was build against an external
|
||||
# pybind11, then it will not be in the above search path and must be
|
||||
# resolved here, in the hope that it is the same one we were configured
|
||||
# with (which it should be if installed via pip). This is really fragile,
|
||||
# though, causing cast failures at runtime if we get it wrong. Come up with
|
||||
# a way to strengthen this.
|
||||
${pybind11_INCLUDE_DIR}
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
|
@ -28,7 +35,11 @@ target_link_libraries(NPCOMPTorchMLIRExt
|
|||
# NPCOMP shared library.
|
||||
NPCOMP
|
||||
)
|
||||
|
||||
add_dependencies(NPCOMPTorchMLIRExt
|
||||
# Uses of the torch_mlir extension also require the npcomp extension to
|
||||
# be built.
|
||||
NPCOMPNativePyExt
|
||||
)
|
||||
set_target_properties(NPCOMPTorchMLIRExt PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/python
|
||||
OUTPUT_NAME _torch_mlir
|
||||
|
|
|
@ -5,12 +5,20 @@ include_directories(
|
|||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
# TODO: Fix implicit ordering. If PyTorch was build against an external
|
||||
# pybind11, then it will not be in the above search path and must be
|
||||
# resolved here, in the hope that it is the same one we were configured
|
||||
# with (which it should be if installed via pip). This is really fragile,
|
||||
# though, causing cast failures at runtime if we get it wrong. Come up with
|
||||
# a way to strengthen this.
|
||||
${pybind11_INCLUDE_DIR}
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
add_library(npcomp_torch_builder_bindings
|
||||
acap_dispatch.cpp
|
||||
debug.cpp
|
||||
func_builder.cpp
|
||||
graph_importer.cpp
|
||||
module_builder.cpp
|
||||
python_bindings.cpp
|
||||
)
|
||||
|
|
|
@ -42,63 +42,17 @@ static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
|
|||
static c10::DispatchKey kAcapGradDispatchKey =
|
||||
c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;
|
||||
|
||||
AcapController::KernelCallBuilder::KernelCallBuilder(
|
||||
AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder(
|
||||
AcapController &parent, MlirContext context, MlirLocation loc,
|
||||
const c10::OperatorHandle &opHandle,
|
||||
llvm::Optional<std::string> overrideKernelName)
|
||||
: parent(parent), context(context), loc(loc), opHandle(opHandle),
|
||||
state("torch.kernel_call", loc) {
|
||||
(void)this->context; // Preserve for future.
|
||||
const std::string &kernelName =
|
||||
overrideKernelName ? *overrideKernelName : opHandle.operator_name().name;
|
||||
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
|
||||
"kernelName",
|
||||
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
|
||||
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
|
||||
addSchemaAttrs();
|
||||
}
|
||||
: KernelCallBuilder(context, loc,
|
||||
overrideKernelName ? *overrideKernelName
|
||||
: opHandle.operator_name().name,
|
||||
opHandle.schema()),
|
||||
parent(parent), opHandle(opHandle) {}
|
||||
|
||||
void AcapController::KernelCallBuilder::addSchemaAttrs() {
|
||||
// Map the op schema to the kernel_call attributes:
|
||||
// sigArgTypes
|
||||
// sigRetTypes
|
||||
// sigIsVararg
|
||||
// sigIsVarret
|
||||
// sigIsMutable
|
||||
const c10::FunctionSchema &schema = opHandle.schema();
|
||||
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsVarret", mlirBoolAttrGet(context, schema.is_varret())));
|
||||
|
||||
// Arg types.
|
||||
llvm::SmallVector<MlirAttribute, 4> args;
|
||||
for (auto &arg : schema.arguments()) {
|
||||
const std::string &typeStr = arg.type()->str();
|
||||
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||
}
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data())));
|
||||
|
||||
// Return types.
|
||||
llvm::SmallVector<MlirAttribute, 4> returns;
|
||||
for (auto &ret : schema.returns()) {
|
||||
const std::string &typeStr = ret.type()->str();
|
||||
returns.push_back(
|
||||
mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||
}
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigRetTypes",
|
||||
mlirArrayAttrGet(context, returns.size(), returns.data())));
|
||||
|
||||
// Add attrs to op.
|
||||
mlirOperationStateAddAttributes(state, attrs.size(), attrs.data());
|
||||
}
|
||||
|
||||
void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
|
||||
void AcapController::TracedKernelCallBuilder::addOperand(const IValue &value) {
|
||||
MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
|
||||
if (mlirValueIsNull(mlirValue)) {
|
||||
std::stringstream out;
|
||||
|
@ -107,10 +61,10 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
|
|||
<< value.tagKind() << "): " << value;
|
||||
throw std::invalid_argument(out.str());
|
||||
}
|
||||
mlirOperationStateAddOperands(state, 1, &mlirValue);
|
||||
KernelCallBuilder::addOperand(mlirValue);
|
||||
}
|
||||
|
||||
void AcapController::KernelCallBuilder::addResult(const IValue &value) {
|
||||
void AcapController::TracedKernelCallBuilder::addResult(const IValue &value) {
|
||||
MlirType resultType = parent.mapIValueToMlirType(loc, value);
|
||||
if (mlirTypeIsNull(resultType)) {
|
||||
std::stringstream out;
|
||||
|
@ -122,12 +76,11 @@ void AcapController::KernelCallBuilder::addResult(const IValue &value) {
|
|||
if (value.isTensor()) {
|
||||
resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
|
||||
}
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
KernelCallBuilder::addResultType(resultType);
|
||||
}
|
||||
|
||||
MlirOperation AcapController::KernelCallBuilder::create() {
|
||||
// Create operation.
|
||||
MlirOperation op = state.createOperation();
|
||||
MlirOperation AcapController::TracedKernelCallBuilder::create() {
|
||||
MlirOperation op = KernelCallBuilder::create();
|
||||
parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
|
||||
|
||||
// Map result tensors.
|
||||
|
@ -264,7 +217,7 @@ at::Tensor AcapController::convolutionKernel(
|
|||
MlirContext context = current->funcBuilder->getContext();
|
||||
MlirLocation loc = current->getCurrentLocation();
|
||||
std::string kernelName{"aten::convolution"};
|
||||
KernelCallBuilder callBuilder{*current, context, loc, *opHandle};
|
||||
TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle};
|
||||
|
||||
callBuilder.addOperand(IValue(input));
|
||||
callBuilder.addOperand(IValue(weight));
|
||||
|
@ -344,8 +297,8 @@ AcapController::mklConvolutionBackward(
|
|||
""};
|
||||
auto emitOpHandle = dispatcher.findOp(emitOpName);
|
||||
assert(emitOpHandle && "could not find convolution_backward_overrideable op");
|
||||
KernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle,
|
||||
kernelName};
|
||||
TracedKernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle,
|
||||
kernelName};
|
||||
|
||||
callBuilder.addOperand(IValue(grad_output));
|
||||
callBuilder.addOperand(IValue(input));
|
||||
|
@ -398,7 +351,7 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
|
|||
|
||||
MlirContext context = current->funcBuilder->getContext();
|
||||
MlirLocation loc = current->getCurrentLocation();
|
||||
KernelCallBuilder callBuilder{*current, context, loc, *opHandle};
|
||||
TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle};
|
||||
|
||||
callBuilder.addOperand(IValue(self));
|
||||
callBuilder.addOperand(IValue(src));
|
||||
|
@ -462,7 +415,7 @@ void AcapController::fallbackKernelImpl(
|
|||
MlirContext context = funcBuilder->getContext();
|
||||
MlirLocation loc = getCurrentLocation();
|
||||
auto kernelName = schema.name();
|
||||
KernelCallBuilder callBuilder{*this, context, loc, opHandle};
|
||||
TracedKernelCallBuilder callBuilder{*this, context, loc, opHandle};
|
||||
|
||||
// Map arguments to operands.
|
||||
// This must be accumulated into the OperationState prior to re-dispatch
|
||||
|
@ -553,7 +506,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
||||
const IValue &ival) {
|
||||
if (ival.isScalar()) {
|
||||
return typeMapper.mapScalarType(ival.toScalar().type());
|
||||
return typeMapper.mapFromTorchScalarType(ival.toScalar().type());
|
||||
}
|
||||
if (ival.isTensor()) {
|
||||
return typeMapper.forwardTensorToType(ival.toTensor());
|
||||
|
@ -601,7 +554,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
|||
// Basicpy bool type.
|
||||
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||
} else {
|
||||
elementType = typeMapper.mapScalarType(tensor.scalar_type());
|
||||
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
|
||||
}
|
||||
llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
|
||||
tensor.sizes().end());
|
||||
|
|
|
@ -80,9 +80,9 @@ public:
|
|||
|
||||
private:
|
||||
/// Builds a kernel call step by step.
|
||||
class KernelCallBuilder {
|
||||
class TracedKernelCallBuilder : private KernelCallBuilder {
|
||||
public:
|
||||
KernelCallBuilder(
|
||||
TracedKernelCallBuilder(
|
||||
AcapController &parent, MlirContext context, MlirLocation loc,
|
||||
const c10::OperatorHandle &opHandle,
|
||||
llvm::Optional<std::string> overrideKernelName = llvm::None);
|
||||
|
@ -91,12 +91,8 @@ private:
|
|||
MlirOperation create();
|
||||
|
||||
private:
|
||||
void addSchemaAttrs();
|
||||
AcapController &parent;
|
||||
MlirContext context;
|
||||
MlirLocation loc;
|
||||
const c10::OperatorHandle &opHandle;
|
||||
OperationStateHolder state;
|
||||
int resultCount = 0;
|
||||
llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
|
||||
};
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include "func_builder.h"
|
||||
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
|
@ -22,7 +23,90 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
|
|||
return s.createOperation();
|
||||
}
|
||||
|
||||
MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) {
|
||||
KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc,
|
||||
llvm::StringRef kernelName,
|
||||
const c10::FunctionSchema &schema)
|
||||
: context(context), loc(loc), state("torch.kernel_call", loc),
|
||||
kernelName(kernelName), schema(schema) {
|
||||
(void)this->context; // Preserve for future.
|
||||
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
|
||||
"kernelName",
|
||||
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
|
||||
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
|
||||
addSchemaAttrs();
|
||||
}
|
||||
|
||||
void KernelCallBuilder::addSchemaAttrs() {
|
||||
// Map the op schema to the kernel_call attributes:
|
||||
// sigArgTypes
|
||||
// sigRetTypes
|
||||
// sigIsVararg
|
||||
// sigIsVarret
|
||||
// sigIsMutable
|
||||
llvm::SmallVector<MlirNamedAttribute, 8> attrs;
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsMutable", mlirBoolAttrGet(context, schema.is_mutable())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsVararg", mlirBoolAttrGet(context, schema.is_vararg())));
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigIsVarret", mlirBoolAttrGet(context, schema.is_varret())));
|
||||
|
||||
// Arg types.
|
||||
llvm::SmallVector<MlirAttribute, 4> args;
|
||||
for (auto &arg : schema.arguments()) {
|
||||
const std::string &typeStr = arg.type()->str();
|
||||
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||
}
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigArgTypes", mlirArrayAttrGet(context, args.size(), args.data())));
|
||||
|
||||
// Return types.
|
||||
llvm::SmallVector<MlirAttribute, 4> returns;
|
||||
for (auto &ret : schema.returns()) {
|
||||
const std::string &typeStr = ret.type()->str();
|
||||
returns.push_back(
|
||||
mlirStringAttrGet(context, typeStr.size(), typeStr.data()));
|
||||
}
|
||||
attrs.push_back(mlirNamedAttributeGet(
|
||||
"sigRetTypes",
|
||||
mlirArrayAttrGet(context, returns.size(), returns.data())));
|
||||
|
||||
// Add attrs to op.
|
||||
mlirOperationStateAddAttributes(state, attrs.size(), attrs.data());
|
||||
}
|
||||
|
||||
void KernelCallBuilder::addOperand(MlirValue operand) {
|
||||
mlirOperationStateAddOperands(state, 1, &operand);
|
||||
}
|
||||
|
||||
void KernelCallBuilder::addResultType(MlirType resultType) {
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
}
|
||||
|
||||
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
|
||||
|
||||
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
|
||||
auto type = rawMapFromTorchScalarType(scalarType);
|
||||
if (mlirTypeIsNull(type)) {
|
||||
std::stringstream message;
|
||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||
throw std::invalid_argument(message.str());
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
|
||||
c10::ScalarType scalarType) {
|
||||
auto type = rawMapFromTorchScalarType(scalarType);
|
||||
if (mlirTypeIsNull(type)) {
|
||||
std::stringstream message;
|
||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||
mlirEmitError(loc, message.str().c_str());
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
||||
using c10::ScalarType;
|
||||
switch (scalarType) {
|
||||
case ScalarType::Byte:
|
||||
|
@ -49,10 +133,49 @@ MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) {
|
|||
return mlirBF16TypeGet(context);
|
||||
case ScalarType::Half:
|
||||
return mlirF16TypeGet(context);
|
||||
default: {
|
||||
return {nullptr};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||
const c10::TypePtr &torchType) {
|
||||
using c10::TypeKind;
|
||||
auto kind = torchType->kind();
|
||||
switch (kind) {
|
||||
case TypeKind::TensorType: {
|
||||
auto tensorType = torchType->cast<c10::TensorType>();
|
||||
// Element type.
|
||||
MlirType elementType;
|
||||
if (tensorType->scalarType()) {
|
||||
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
|
||||
if (mlirTypeIsNull(elementType))
|
||||
return {nullptr};
|
||||
} else {
|
||||
elementType = npcompAnyDtypeTypeGet(context);
|
||||
}
|
||||
// Sizes.
|
||||
auto &sizes = tensorType->symbolic_sizes();
|
||||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return npcompNdArrayTypeGetUnranked(elementType);
|
||||
}
|
||||
// Ranked with possibly dynamic dims.
|
||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
||||
llvm::SmallVector<int64_t, 4> dims;
|
||||
dims.resize(*sizes.rank());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||
throw std::invalid_argument(message.str());
|
||||
message << "unable to map Torch type " << torchType << " to MLIR type";
|
||||
mlirEmitError(loc, message.str().c_str());
|
||||
return {nullptr};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +187,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
return npcompNoneTypeGet(context);
|
||||
}
|
||||
|
||||
MlirType elementType = mapScalarType(tensor.scalar_type());
|
||||
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
||||
// TODO: Decide when it is necessary to take strides into account. Right now,
|
||||
// just erase them and let the compiler decide.
|
||||
|
||||
|
@ -73,9 +196,10 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
}
|
||||
|
||||
std::unique_ptr<FuncBuilder>
|
||||
FuncBuilder::createFunction(MlirContext context, MlirLocation location,
|
||||
llvm::StringRef name,
|
||||
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
||||
MlirLocation location, llvm::StringRef name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes) {
|
||||
auto context = mlirLocationGetContext(location);
|
||||
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
|
||||
// (this is fragile and reveals details that are not guaranteed).
|
||||
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
||||
|
@ -102,6 +226,7 @@ FuncBuilder::createFunction(MlirContext context, MlirLocation location,
|
|||
MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
|
||||
MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);
|
||||
|
||||
inserter(funcOp);
|
||||
return std::unique_ptr<FuncBuilder>(new FuncBuilder(
|
||||
context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
|
@ -51,21 +52,30 @@ public:
|
|||
|
||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||
/// Throws std::invalid_argument on failure.
|
||||
MlirType mapScalarType(c10::ScalarType scalarType);
|
||||
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
|
||||
|
||||
/// Gets a corresponding MlirType for the forward component of a tensor.
|
||||
/// Throws std::invalid_argument on failure.
|
||||
MlirType forwardTensorToType(at::Tensor tensor);
|
||||
|
||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||
/// Returns a null type on failure and emits a diagnostic.
|
||||
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
|
||||
|
||||
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
||||
/// on failure and emits a diagnostic.
|
||||
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
|
||||
|
||||
private:
|
||||
/// Maps from a scalar type and returns null if no match (no other error
|
||||
/// reporting).
|
||||
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
|
||||
MlirContext context;
|
||||
};
|
||||
|
||||
/// Wraps an MlirBlock under construction, primarily tracking the terminator
|
||||
/// and supporting manipulation of it. The terminator may be null if it has
|
||||
/// not yet been constructed, although, for entry blocks, we always construct
|
||||
/// the function with an appropriate return terminator (which can be changed
|
||||
/// later).
|
||||
/// not yet been constructed.
|
||||
class BlockBuilder {
|
||||
public:
|
||||
BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn)
|
||||
|
@ -86,15 +96,39 @@ private:
|
|||
bool isReturn;
|
||||
};
|
||||
|
||||
/// Builds a kernel call step by step.
|
||||
class KernelCallBuilder {
|
||||
public:
|
||||
KernelCallBuilder(MlirContext context, MlirLocation loc,
|
||||
llvm::StringRef kernelName,
|
||||
const c10::FunctionSchema &schema);
|
||||
void addOperand(MlirValue operand);
|
||||
void addResultType(MlirType resultType);
|
||||
MlirOperation create();
|
||||
|
||||
protected:
|
||||
MlirContext context;
|
||||
MlirLocation loc;
|
||||
|
||||
private:
|
||||
void addSchemaAttrs();
|
||||
OperationStateHolder state;
|
||||
llvm::StringRef kernelName;
|
||||
const c10::FunctionSchema &schema;
|
||||
};
|
||||
|
||||
/// Wraps a 'func' MlirOperation and provides facilities for constructing
|
||||
/// IR from some stream of Torch operations.
|
||||
class FuncBuilder {
|
||||
public:
|
||||
/// Callback for inserting a function.
|
||||
using Inserter = std::function<void(MlirOperation funcOp)>;
|
||||
|
||||
/// Creates a new func op with the given characteristics. The created
|
||||
/// operation is not attached. The caller must either destroy it or add it
|
||||
/// to a parent.
|
||||
static std::unique_ptr<FuncBuilder>
|
||||
createFunction(MlirContext context, MlirLocation location,
|
||||
createFunction(Inserter &inserter, MlirLocation location,
|
||||
llvm::StringRef name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes);
|
||||
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
//===- graph_importer.cpp -------------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "graph_importer.h"
|
||||
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// GraphImporter::NodeScope implementation
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// A scope of Graph Value * to corresponding MlirValue. Scopes nest
|
||||
// region-wise. Note that in PyTorch, the thing called 'Block' is analagous
|
||||
// to a capturing MLIR region.
|
||||
class GraphImporter::NodeScope {
|
||||
public:
|
||||
NodeScope() = default;
|
||||
NodeScope(NodeScope *prev) : prev(prev) {}
|
||||
|
||||
void bindValue(torch::jit::Value *torchValue, MlirValue value);
|
||||
MlirValue findValue(torch::jit::Value *torchValue);
|
||||
MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue);
|
||||
|
||||
private:
|
||||
llvm::DenseMap<torch::jit::Value *, MlirValue> valueMap;
|
||||
NodeScope *prev = nullptr;
|
||||
};
|
||||
|
||||
void GraphImporter::NodeScope::bindValue(torch::jit::Value *torchValue,
|
||||
MlirValue value) {
|
||||
assert(valueMap.count(torchValue) == 0 && "duplicate torch Value bind");
|
||||
valueMap[torchValue] = value;
|
||||
}
|
||||
|
||||
MlirValue GraphImporter::NodeScope::findValue(torch::jit::Value *torchValue) {
|
||||
auto foundIt = valueMap.find(torchValue);
|
||||
if (foundIt == valueMap.end()) {
|
||||
if (prev)
|
||||
return prev->findValue(torchValue);
|
||||
else
|
||||
return {nullptr};
|
||||
}
|
||||
return foundIt->second;
|
||||
}
|
||||
|
||||
MlirValue
|
||||
GraphImporter::NodeScope::findRequiredValue(MlirLocation loc,
|
||||
torch::jit::Value *torchValue) {
|
||||
MlirValue value = findValue(torchValue);
|
||||
if (mlirValueIsNull(value)) {
|
||||
std::stringstream msg;
|
||||
msg << "internal error: unmapped torch value: %" << torchValue->debugName();
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// GraphImporter::NodeImporter implementation
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
/// Helper class to import a torch::jit::Node into an MLIR function.
|
||||
/// This class primarily exists to eliminate the need for large lists of
|
||||
/// carried arguments related to doing the import.
|
||||
class GraphImporter::NodeImporter {
|
||||
public:
|
||||
NodeImporter(torch::jit::Node *node, GraphImporter &parent,
|
||||
FuncBuilder *funcBuilder, MlirBlock block, MlirOperation ip,
|
||||
NodeScope *scope);
|
||||
|
||||
void importNode();
|
||||
void importReturnOp();
|
||||
|
||||
private:
|
||||
MlirContext context() { return parent.context(); }
|
||||
void importPrimNode();
|
||||
MlirAttribute importValueAttribute();
|
||||
|
||||
torch::jit::Node *node;
|
||||
GraphImporter &parent;
|
||||
FuncBuilder *funcBuilder;
|
||||
MlirBlock block;
|
||||
MlirOperation ip;
|
||||
NodeScope *scope;
|
||||
MlirLocation loc;
|
||||
};
|
||||
|
||||
GraphImporter::NodeImporter::NodeImporter(torch::jit::Node *node,
|
||||
GraphImporter &parent,
|
||||
FuncBuilder *funcBuilder,
|
||||
MlirBlock block, MlirOperation ip,
|
||||
NodeScope *scope)
|
||||
: node(node), parent(parent), funcBuilder(funcBuilder), block(block),
|
||||
ip(ip), scope(scope) {
|
||||
loc = parent.extractCallstackLoc(node);
|
||||
}
|
||||
|
||||
void GraphImporter::NodeImporter::importNode() {
|
||||
// Prim namespace handled specially.
|
||||
auto kind = node->kind();
|
||||
if (kind.ns() == c10::namespaces::prim) {
|
||||
importPrimNode();
|
||||
return;
|
||||
}
|
||||
|
||||
// Generic import.
|
||||
auto funcSchema = node->maybeSchema();
|
||||
if (funcSchema) {
|
||||
KernelCallBuilder kcb(context(), loc, kind.toQualString(), *funcSchema);
|
||||
for (auto *input : node->inputs()) {
|
||||
kcb.addOperand(scope->findRequiredValue(loc, input));
|
||||
}
|
||||
for (auto *output : node->outputs()) {
|
||||
MlirType type =
|
||||
parent.type_mapper().mapFromTorchType(loc, output->type());
|
||||
if (mlirTypeIsNull(type)) {
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
kcb.addResultType(type);
|
||||
}
|
||||
MlirOperation op = kcb.create();
|
||||
mlirBlockInsertOwnedOperationBefore(block, ip, op);
|
||||
|
||||
// Map results.
|
||||
for (auto it : llvm::enumerate(node->outputs())) {
|
||||
scope->bindValue(it.value(), mlirOperationGetResult(op, it.index()));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// No soup for you. Not exactly sure when this can happen.
|
||||
{
|
||||
std::stringstream msg;
|
||||
msg << "unhandled: generic operation " << kind.toDisplayString();
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
}
|
||||
|
||||
void GraphImporter::NodeImporter::importReturnOp() {
|
||||
OperationStateHolder s("std.return", loc);
|
||||
llvm::SmallVector<MlirValue, 4> returnsValues;
|
||||
for (auto *input : node->inputs()) {
|
||||
returnsValues.push_back(scope->findRequiredValue(loc, input));
|
||||
}
|
||||
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
|
||||
mlirBlockInsertOwnedOperationBefore(block, ip, s.createOperation());
|
||||
}
|
||||
|
||||
void GraphImporter::NodeImporter::importPrimNode() {
|
||||
auto kind = node->kind();
|
||||
if (kind == c10::prim::Constant) {
|
||||
auto output = node->output();
|
||||
MlirAttribute valueAttr = importValueAttribute();
|
||||
MlirValue constValue = funcBuilder->getGeneralConstant(loc, valueAttr);
|
||||
scope->bindValue(output, constValue);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unhandled.
|
||||
{
|
||||
std::stringstream msg;
|
||||
msg << "unhandled: prim operation " << kind.toDisplayString();
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
}
|
||||
|
||||
MlirAttribute GraphImporter::NodeImporter::importValueAttribute() {
|
||||
using torch::jit::AttributeKind;
|
||||
auto s = c10::attr::value;
|
||||
auto kind = node->kindOf(s);
|
||||
switch (kind) {
|
||||
case AttributeKind::i:
|
||||
// TODO: This should be a signed int once we have a constant op that can
|
||||
// do that.
|
||||
return mlirIntegerAttrGet(mlirIntegerTypeGet(context(), 64), node->i(s));
|
||||
break;
|
||||
case AttributeKind::f:
|
||||
return mlirFloatAttrDoubleGet(context(), mlirF64TypeGet(context()),
|
||||
node->f(s));
|
||||
break;
|
||||
|
||||
default: {
|
||||
std::stringstream msg;
|
||||
msg << "unhandled: value attribute kind " << toString(kind);
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// GraphImporter implementation
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
GraphImporter::GraphImporter(std::shared_ptr<torch::jit::Graph> graph,
|
||||
MlirMappingOptions mappingOptions)
|
||||
: graph(std::move(graph)), mappingOptions(std::move(mappingOptions)) {}
|
||||
|
||||
std::shared_ptr<GraphImporter> GraphImporter::forPythonJitFunc(
|
||||
torch::jit::Function *function,
|
||||
GraphImporter::MlirMappingOptions mappingOptions) {
|
||||
// Disallow an attempt to compile a native function.
|
||||
if (!function->isGraphFunction()) {
|
||||
throw std::invalid_argument(
|
||||
"Expected a torch.jit.ScriptFunction with a graph");
|
||||
}
|
||||
auto graph = function->graph();
|
||||
if (!mappingOptions.genericFuncName) {
|
||||
mappingOptions.genericFuncName = function->name() + "$generic";
|
||||
}
|
||||
if (!mappingOptions.funcName) {
|
||||
mappingOptions.funcName = function->name() + "$generic";
|
||||
}
|
||||
return std::make_shared<GraphImporter>(graph, std::move(mappingOptions));
|
||||
}
|
||||
|
||||
void GraphImporter::initialize() {
|
||||
defaultLoc = mlirLocationUnknownGet(context());
|
||||
// There is not a callstack associated with the graph so, try to grab
|
||||
// a location from the first node that has one as a better than nothing
|
||||
// thing.
|
||||
// TODO: This doesn't actually seem to be working. Investigate when more
|
||||
// examples are built out.
|
||||
for (auto *node : graph->nodes()) {
|
||||
MlirLocation nodeLoc = extractCallstackLoc(node, /*useDefault=*/false);
|
||||
if (nodeLoc.ptr) {
|
||||
defaultLoc = nodeLoc;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Map inputs.
|
||||
MlirLocation inputLoc = extractCallstackLoc(graph->param_node());
|
||||
for (const auto &input : graph->inputs()) {
|
||||
MlirType t = type_mapper().mapFromTorchType(inputLoc, input->type());
|
||||
if (mlirTypeIsNull(t))
|
||||
throw mlir_diagnostic_emitted("could not convert function input type");
|
||||
genericFuncArgTypes.push_back(t);
|
||||
}
|
||||
|
||||
// Map outputs.
|
||||
MlirLocation outputLoc = extractCallstackLoc(graph->return_node());
|
||||
for (const auto &output : graph->outputs()) {
|
||||
MlirType t = type_mapper().mapFromTorchType(outputLoc, output->type());
|
||||
if (mlirTypeIsNull(t))
|
||||
throw mlir_diagnostic_emitted("could not convert function output type");
|
||||
genericFuncReturnTypes.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphImporter::importGenericFunc() {
|
||||
auto funcBuilder = FuncBuilder::createFunction(
|
||||
mappingOptions.inserter, defaultLoc, *mappingOptions.genericFuncName,
|
||||
genericFuncArgTypes);
|
||||
funcBuilder->rewriteFuncReturnTypes(genericFuncReturnTypes);
|
||||
MlirBlock entryBlock = funcBuilder->getEntryBlock();
|
||||
|
||||
// Bind inputs.
|
||||
NodeScope scope;
|
||||
for (const auto &it : llvm::enumerate(graph->inputs())) {
|
||||
MlirValue value = mlirBlockGetArgument(entryBlock, it.index());
|
||||
scope.bindValue(it.value(), value);
|
||||
}
|
||||
|
||||
// Walk body nodes.
|
||||
for (auto *node : graph->nodes()) {
|
||||
NodeImporter importer{
|
||||
node, *this, funcBuilder.get(), entryBlock, /*ip=*/{nullptr}, &scope};
|
||||
importer.importNode();
|
||||
}
|
||||
|
||||
// Map the output node to a return.
|
||||
auto *outputNode = graph->return_node();
|
||||
NodeImporter returnImporter{outputNode, *this,
|
||||
funcBuilder.get(), entryBlock,
|
||||
/*ip=*/{nullptr}, &scope};
|
||||
returnImporter.importReturnOp();
|
||||
}
|
||||
|
||||
MlirLocation GraphImporter::extractCallstackLoc(torch::jit::Node *node,
|
||||
bool useDefault) {
|
||||
auto flc = node->sourceRange().file_line_col();
|
||||
if (flc) {
|
||||
const std::string &file = std::get<0>(*flc);
|
||||
int line = std::get<1>(*flc);
|
||||
int col = std::get<2>(*flc);
|
||||
return mlirLocationFileLineColGet(context(), file.c_str(), line, col);
|
||||
}
|
||||
|
||||
return useDefault ? defaultLoc : MlirLocation{nullptr};
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "func_builder.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Main entry-point for importing torch::jit::Graph instances (and structures
|
||||
/// surrounding them such as modules and methods).
|
||||
///
|
||||
/// In torch terminology, a Graph is a function. Later in the compiler, we may
|
||||
/// specialize multiple versions of it.
|
||||
///
|
||||
/// Since graph functions typically have enough annotations for the most
|
||||
/// generic form of every type (i.e. Tensor, List, etc), and since we often
|
||||
/// want to multi-version specific specializations, we take the approach of
|
||||
/// generating a '$generic' suffixed function at that level and then generate
|
||||
/// the actual named function with using a 'numpy.generic_call' op to invoke
|
||||
/// the generic function with metadata controlling how it is legal to
|
||||
/// specialize. This leaves the process of inlining and expanding the
|
||||
/// specializations to compiler passes.
|
||||
class GraphImporter : public std::enable_shared_from_this<GraphImporter> {
|
||||
public:
|
||||
/// Options for mapping Graph concepts to MLIR. In addition to things like
|
||||
/// names and type mappings, this includes various policy options such as
|
||||
/// when to import globals as constants vs shared arrays, etc.
|
||||
struct MlirMappingOptions {
|
||||
MlirContext context;
|
||||
llvm::Optional<std::string> genericFuncName;
|
||||
llvm::Optional<std::string> funcName;
|
||||
TypeMapper &typeMapper;
|
||||
FuncBuilder::Inserter &inserter;
|
||||
};
|
||||
/// Construct an importer.
|
||||
GraphImporter(std::shared_ptr<torch::jit::Graph> graph,
|
||||
MlirMappingOptions mappingOptions);
|
||||
|
||||
/// Helper to create a graph importer from a traced/scripted python function.
|
||||
/// If the funcName of the mapping options is not set, it is set from the
|
||||
/// function name. It is the responsibility of the caller to ensure that the
|
||||
/// funcObj and associated graph outlives this instance.
|
||||
static std::shared_ptr<GraphImporter>
|
||||
forPythonJitFunc(torch::jit::Function *function,
|
||||
MlirMappingOptions mappingOptions);
|
||||
|
||||
/// Initialize for import. This is separate from the constructor purely for
|
||||
/// ergonomics and must be called post-construction. Initialization activities
|
||||
/// that throw go here.
|
||||
void initialize();
|
||||
|
||||
/// Imports the generic function into the module.
|
||||
void importGenericFunc();
|
||||
|
||||
private:
|
||||
class NodeScope;
|
||||
class NodeImporter;
|
||||
|
||||
MlirContext context() { return mappingOptions.context; }
|
||||
TypeMapper &type_mapper() { return mappingOptions.typeMapper; }
|
||||
MlirLocation extractCallstackLoc(torch::jit::Node *node,
|
||||
bool useDefault = true);
|
||||
std::shared_ptr<torch::jit::Graph> graph;
|
||||
MlirMappingOptions mappingOptions;
|
||||
|
||||
/// Default function location, to be used when a more specific is not
|
||||
/// available.
|
||||
MlirLocation defaultLoc;
|
||||
|
||||
/// Argument and return types for the generic func.
|
||||
llvm::SmallVector<MlirType, 4> genericFuncArgTypes;
|
||||
llvm::SmallVector<MlirType, 4> genericFuncReturnTypes;
|
||||
};
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
#include "module_builder.h"
|
||||
|
||||
#include "graph_importer.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
|
@ -78,11 +80,9 @@ ModuleBuilder::startCaptureFunction(std::string &name,
|
|||
}
|
||||
|
||||
// TODO: Extract a traceback and use in place of unknownLoc.
|
||||
auto inserter = createInserter();
|
||||
auto funcBuilder =
|
||||
FuncBuilder::createFunction(context, unknownLoc, name, inputTypes);
|
||||
mlirBlockInsertOwnedOperationBefore(getBodyBlock(), terminator,
|
||||
funcBuilder->getFuncOp());
|
||||
|
||||
FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
|
||||
// Map block arguments.
|
||||
MlirBlock entryBlock = funcBuilder->getEntryBlock();
|
||||
assert(mlirBlockGetNumArguments(entryBlock) ==
|
||||
|
@ -95,6 +95,28 @@ ModuleBuilder::startCaptureFunction(std::string &name,
|
|||
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
|
||||
}
|
||||
|
||||
void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||
auto inserter = createInserter();
|
||||
GraphImporter::MlirMappingOptions mappingOptions{
|
||||
context,
|
||||
llvm::None, // genericFuncName (default to auto)
|
||||
llvm::None, // funcName (default to auto)
|
||||
typeMapper, inserter,
|
||||
};
|
||||
auto graphImporter = GraphImporter::forPythonJitFunc(
|
||||
function.function_, std::move(mappingOptions));
|
||||
graphImporter->initialize();
|
||||
graphImporter->importGenericFunc();
|
||||
}
|
||||
|
||||
FuncBuilder::Inserter ModuleBuilder::createInserter() {
|
||||
MlirBlock block = getBodyBlock();
|
||||
MlirOperation terminator = this->terminator;
|
||||
return [=](MlirOperation op) {
|
||||
mlirBlockInsertOwnedOperationBefore(block, terminator, op);
|
||||
};
|
||||
}
|
||||
|
||||
MlirBlock ModuleBuilder::getBodyBlock() {
|
||||
MlirOperation moduleOp = mlirModuleGetOperation(module);
|
||||
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||
|
@ -106,5 +128,6 @@ void ModuleBuilder::bind(py::module &m) {
|
|||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||
.def("capture_function", &ModuleBuilder::startCaptureFunction,
|
||||
py::keep_alive<0, 1>());
|
||||
py::keep_alive<0, 1>())
|
||||
.def("import_function", &ModuleBuilder::importFunction);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
|
@ -32,11 +34,16 @@ public:
|
|||
pybind11::object getModuleObj() { return moduleObj; }
|
||||
|
||||
// Starts a device-capture based function.
|
||||
// TODO: Add inputs.
|
||||
std::shared_ptr<AcapController>
|
||||
startCaptureFunction(std::string &name, std::vector<at::Tensor> args);
|
||||
|
||||
// Imports a traced function. Note that the python type
|
||||
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
||||
// Just a bit of naming cruft.
|
||||
void importFunction(torch::jit::StrongFunctionPtr function);
|
||||
|
||||
private:
|
||||
FuncBuilder::Inserter createInserter();
|
||||
MlirBlock getBodyBlock();
|
||||
|
||||
// Capture references to the python-owned context and module. Ownership
|
||||
|
|
|
@ -140,4 +140,3 @@ void torch_mlir::InitBuilderBindings(py::module &m) {
|
|||
|
||||
ModuleBuilder::bind(m);
|
||||
}
|
||||
|
||||
|
|
|
@ -14,4 +14,15 @@
|
|||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Thrown on failure when details are in MLIR emitted diagnostics.
|
||||
class mlir_diagnostic_emitted : public std::runtime_error {
|
||||
public:
|
||||
mlir_diagnostic_emitted(const char *what) : std::runtime_error(what) {}
|
||||
mlir_diagnostic_emitted() : std::runtime_error("see diagnostics") {}
|
||||
};
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
@torch.jit.script
|
||||
def add3(t0, t1, t2):
|
||||
return t0 + t1 + t2
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
mb.import_function(add3)
|
||||
|
||||
# Verify without debug info.
|
||||
# CHECK-LABEL: func @add3$generic
|
||||
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[C1:.*]] = constant 1 : i64
|
||||
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
# CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
# CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
mb.module.operation.print()
|
||||
print()
|
|
@ -0,0 +1,26 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
@torch.jit.script
|
||||
def add3(t0, t1, t2):
|
||||
intermediate = t0 + t1
|
||||
final = intermediate + t2
|
||||
return final
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
mb.import_function(add3)
|
||||
|
||||
# Verify again with debug info present. Just checking that it makes it in there.
|
||||
# CHECK-LABEL: func @add3$generic
|
||||
# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
|
||||
# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
|
||||
# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
|
||||
# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
|
||||
mb.module.operation.print(enable_debug_info=True)
|
||||
print()
|
|
@ -64,6 +64,9 @@ MlirType npcompListTypeGet(MlirContext context);
|
|||
/** Checks whether the given type is an NdArray type. */
|
||||
int npcompTypeIsANdArray(MlirType t);
|
||||
|
||||
/** Gets a numpy.NdArray type that is unranked. */
|
||||
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType);
|
||||
|
||||
/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
|
||||
* unknown. */
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
|
|
|
@ -70,6 +70,10 @@ int npcompTypeIsANdArray(MlirType t) {
|
|||
return unwrap(t).isa<Numpy::NdArrayType>();
|
||||
}
|
||||
|
||||
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) {
|
||||
return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
|
||||
}
|
||||
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType) {
|
||||
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
||||
|
|
Loading…
Reference in New Issue