mirror of https://github.com/llvm/torch-mlir
Add initial TorchScript module importer
It turns out that this was easiest to structure as a general IValue importer, since torch module are just one of the possible IValue's. We import the IValue object graph in a braindead fashion into basicpy ops and a new `torch.nn_module` op that is used to model the attributes/methods of a torch::jit::Module IValue. See `Torch/ops.mlir` for an example, and also check out the .py import tests in `frontends/pytorch/test/module_import`. As part of this change, a few housekeeping tasks: - extract some helpers from graph_importer.cpp - more helpers around the C API - misc touchupspull/155/head
parent
0f6a65a1c5
commit
689b40c7a6
|
@ -16,7 +16,9 @@ add_library(NPCOMPTorchMLIRExt SHARED
|
||||||
builder/func_builder.cpp
|
builder/func_builder.cpp
|
||||||
builder/graph_importer.cpp
|
builder/graph_importer.cpp
|
||||||
builder/module_builder.cpp
|
builder/module_builder.cpp
|
||||||
|
builder/ivalue_importer.cpp
|
||||||
builder/python_bindings.cpp
|
builder/python_bindings.cpp
|
||||||
|
builder/torch_to_mlir_utils.cpp
|
||||||
init_python_bindings.cpp
|
init_python_bindings.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -529,80 +529,15 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
||||||
using at::ScalarType;
|
|
||||||
|
|
||||||
auto throwUnsupportedTensorError = [&]() {
|
|
||||||
std::stringstream msg;
|
|
||||||
msg << "Unsupported import tensor type: " << tensor;
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
|
|
||||||
if (!tensor.is_contiguous())
|
|
||||||
tensor = tensor.contiguous();
|
|
||||||
|
|
||||||
// The flat number of bytes throws an exception for tensors that are not
|
|
||||||
// dense and accessible as such.
|
|
||||||
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
|
|
||||||
c10::Layout::Strided);
|
|
||||||
|
|
||||||
// Construct the ShapedType.
|
|
||||||
auto loc = getCurrentLocation();
|
auto loc = getCurrentLocation();
|
||||||
MlirType elementType;
|
MlirAttribute valueAttribute = converTensorToMlirElementsAttr(tensor, loc);
|
||||||
if (tensor.scalar_type() == ScalarType::Bool) {
|
|
||||||
// Bool is a special case. When used as an element type, it must be i1.
|
|
||||||
// The generalized (non-Tensor) conversion, assumes that Bool is the
|
|
||||||
// Basicpy bool type.
|
|
||||||
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
|
||||||
} else {
|
|
||||||
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
|
|
||||||
}
|
|
||||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
|
||||||
MlirType shapedType = mlirRankedTensorTypeGetChecked(
|
|
||||||
shape.size(), shape.data(), elementType, loc);
|
|
||||||
if (mlirTypeIsNull(shapedType)) {
|
|
||||||
throwUnsupportedTensorError();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Import DenseElementsAttr data.
|
|
||||||
// TODO: Support bool tensors.
|
|
||||||
// TODO: More import formats in C-API.
|
|
||||||
MlirAttribute valueAttribute;
|
|
||||||
auto numElements = tensor.numel();
|
|
||||||
auto tensorData = tensor.data_ptr();
|
|
||||||
switch (tensor.scalar_type()) {
|
|
||||||
case ScalarType::Int:
|
|
||||||
valueAttribute = mlirDenseElementsAttrInt32Get(
|
|
||||||
shapedType, numElements, static_cast<const int32_t *>(tensorData));
|
|
||||||
break;
|
|
||||||
case ScalarType::Long:
|
|
||||||
valueAttribute = mlirDenseElementsAttrInt64Get(
|
|
||||||
shapedType, numElements, static_cast<const int64_t *>(tensorData));
|
|
||||||
break;
|
|
||||||
case ScalarType::Float:
|
|
||||||
valueAttribute = mlirDenseElementsAttrFloatGet(
|
|
||||||
shapedType, numElements, static_cast<const float *>(tensorData));
|
|
||||||
break;
|
|
||||||
case ScalarType::Double:
|
|
||||||
valueAttribute = mlirDenseElementsAttrDoubleGet(
|
|
||||||
shapedType, numElements, static_cast<const double *>(tensorData));
|
|
||||||
break;
|
|
||||||
case ScalarType::Bool:
|
|
||||||
// TODO: Add a test specifically for bool and ensure consistency between
|
|
||||||
// storage format and load format
|
|
||||||
// (https://github.com/llvm/mlir-npcomp/issues/100).
|
|
||||||
valueAttribute = mlirDenseElementsAttrBoolGet(
|
|
||||||
shapedType, numElements, static_cast<const int *>(tensorData));
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throwUnsupportedTensorError();
|
|
||||||
}
|
|
||||||
MlirValue constTensorValue =
|
MlirValue constTensorValue =
|
||||||
funcBuilder->getGeneralConstant(loc, valueAttribute);
|
funcBuilder->getGeneralConstant(loc, valueAttribute);
|
||||||
|
|
||||||
// Create an array from the tensor constant via the
|
// Create an array from the tensor constant via the
|
||||||
// numpy.create_array_from_tensor op.
|
// numpy.create_array_from_tensor op.
|
||||||
MlirType constArrayType = npcompNdArrayTypeGetFromShaped(shapedType);
|
MlirType constArrayType =
|
||||||
|
npcompNdArrayTypeGetFromShaped(mlirAttributeGetType(valueAttribute));
|
||||||
MlirOperationState state = mlirOperationStateGet(
|
MlirOperationState state = mlirOperationStateGet(
|
||||||
toMlirStringRef("numpy.create_array_from_tensor"), loc);
|
toMlirStringRef("numpy.create_array_from_tensor"), loc);
|
||||||
mlirOperationStateAddOperands(&state, 1, &constTensorValue);
|
mlirOperationStateAddOperands(&state, 1, &constTensorValue);
|
||||||
|
|
|
@ -87,116 +87,6 @@ void KernelCallBuilder::addResultType(MlirType resultType) {
|
||||||
|
|
||||||
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
|
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
|
||||||
|
|
||||||
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
|
|
||||||
auto type = rawMapFromTorchScalarType(scalarType);
|
|
||||||
if (mlirTypeIsNull(type)) {
|
|
||||||
std::stringstream message;
|
|
||||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
|
||||||
throw std::invalid_argument(message.str());
|
|
||||||
}
|
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
|
|
||||||
c10::ScalarType scalarType) {
|
|
||||||
auto type = rawMapFromTorchScalarType(scalarType);
|
|
||||||
if (mlirTypeIsNull(type)) {
|
|
||||||
std::stringstream message;
|
|
||||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
|
||||||
mlirEmitError(loc, message.str().c_str());
|
|
||||||
}
|
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
|
||||||
using c10::ScalarType;
|
|
||||||
switch (scalarType) {
|
|
||||||
case ScalarType::Byte:
|
|
||||||
// TODO: convert to mlirIntegerTypeUnsignedGet once supported.
|
|
||||||
return mlirIntegerTypeGet(context, 8);
|
|
||||||
case ScalarType::Char:
|
|
||||||
return mlirIntegerTypeGet(context, 8);
|
|
||||||
case ScalarType::Short:
|
|
||||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
|
||||||
return mlirIntegerTypeGet(context, 16);
|
|
||||||
case ScalarType::Int:
|
|
||||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
|
||||||
return mlirIntegerTypeGet(context, 32);
|
|
||||||
case ScalarType::Long:
|
|
||||||
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
|
||||||
return mlirIntegerTypeGet(context, 64);
|
|
||||||
case ScalarType::Bool:
|
|
||||||
return npcompBoolTypeGet(context);
|
|
||||||
case ScalarType::Double:
|
|
||||||
return mlirF64TypeGet(context);
|
|
||||||
case ScalarType::Float:
|
|
||||||
return mlirF32TypeGet(context);
|
|
||||||
case ScalarType::BFloat16:
|
|
||||||
return mlirBF16TypeGet(context);
|
|
||||||
case ScalarType::Half:
|
|
||||||
return mlirF16TypeGet(context);
|
|
||||||
default: {
|
|
||||||
return {nullptr};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|
||||||
const c10::TypePtr &torchType) {
|
|
||||||
using c10::TypeKind;
|
|
||||||
auto kind = torchType->kind();
|
|
||||||
switch (kind) {
|
|
||||||
case TypeKind::TensorType: {
|
|
||||||
auto tensorType = torchType->cast<c10::TensorType>();
|
|
||||||
// Element type.
|
|
||||||
MlirType elementType;
|
|
||||||
if (tensorType->scalarType()) {
|
|
||||||
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
|
|
||||||
if (mlirTypeIsNull(elementType))
|
|
||||||
return {nullptr};
|
|
||||||
} else {
|
|
||||||
elementType = npcompAnyDtypeTypeGet(context);
|
|
||||||
}
|
|
||||||
// Sizes.
|
|
||||||
auto &sizes = tensorType->symbolic_sizes();
|
|
||||||
if (!sizes.rank()) {
|
|
||||||
// Unranked.
|
|
||||||
return npcompNdArrayTypeGetUnranked(elementType);
|
|
||||||
}
|
|
||||||
// Ranked with possibly dynamic dims.
|
|
||||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
|
||||||
std::vector<int64_t> dims;
|
|
||||||
dims.resize(*sizes.rank());
|
|
||||||
for (size_t i = 0; i < dims.size(); ++i) {
|
|
||||||
auto shapeSymbol = symbolicShape[i];
|
|
||||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
|
||||||
}
|
|
||||||
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
std::stringstream message;
|
|
||||||
message << "unable to map Torch type " << torchType << " to MLIR type";
|
|
||||||
mlirEmitError(loc, message.str().c_str());
|
|
||||||
return {nullptr};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|
||||||
if (!tensor.defined()) {
|
|
||||||
// Undefined tensors are equivalent to None.
|
|
||||||
// This may need to be re-evaluated at some point.
|
|
||||||
return npcompNoneTypeGet(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
|
||||||
// TODO: Decide when it is necessary to take strides into account. Right now,
|
|
||||||
// just erase them and let the compiler decide.
|
|
||||||
|
|
||||||
auto sizes = tensor.sizes();
|
|
||||||
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<FuncBuilder>
|
std::unique_ptr<FuncBuilder>
|
||||||
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
|
||||||
MlirLocation location, const std::string &name,
|
MlirLocation location, const std::string &name,
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_FUNC_BUILDER_H
|
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_FUNC_BUILDER_H
|
||||||
|
|
||||||
#include "mlir_utils.h"
|
#include "mlir_utils.h"
|
||||||
|
#include "torch_to_mlir_utils.h"
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
@ -45,34 +46,6 @@ private:
|
||||||
bool owned = true;
|
bool owned = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Maps various runtime types to MlirType.
|
|
||||||
class TypeMapper {
|
|
||||||
public:
|
|
||||||
TypeMapper(MlirContext context) : context(context) {}
|
|
||||||
|
|
||||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
|
||||||
/// Throws std::invalid_argument on failure.
|
|
||||||
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
|
|
||||||
|
|
||||||
/// Gets a corresponding MlirType for the forward component of a tensor.
|
|
||||||
/// Throws std::invalid_argument on failure.
|
|
||||||
MlirType forwardTensorToType(at::Tensor tensor);
|
|
||||||
|
|
||||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
|
||||||
/// Returns a null type on failure and emits a diagnostic.
|
|
||||||
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
|
|
||||||
|
|
||||||
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
|
||||||
/// on failure and emits a diagnostic.
|
|
||||||
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Maps from a scalar type and returns null if no match (no other error
|
|
||||||
/// reporting).
|
|
||||||
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
|
|
||||||
MlirContext context;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Wraps an MlirBlock under construction, primarily tracking the terminator
|
/// Wraps an MlirBlock under construction, primarily tracking the terminator
|
||||||
/// and supporting manipulation of it. The terminator may be null if it has
|
/// and supporting manipulation of it. The terminator may be null if it has
|
||||||
/// not yet been constructed.
|
/// not yet been constructed.
|
||||||
|
|
|
@ -20,8 +20,10 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
/// Main entry-point for importing torch::jit::Graph instances (and structures
|
/// Main entry-point for importing torch::jit::Graph instances.
|
||||||
/// surrounding them such as modules and methods).
|
///
|
||||||
|
/// This code doesn't handle importing of torch::jit::Module's. See
|
||||||
|
/// ModuleImporter for that.
|
||||||
///
|
///
|
||||||
/// In torch terminology, a Graph is a function. Later in the compiler, we may
|
/// In torch terminology, a Graph is a function. Later in the compiler, we may
|
||||||
/// specialize multiple versions of it.
|
/// specialize multiple versions of it.
|
||||||
|
|
|
@ -0,0 +1,177 @@
|
||||||
|
//===- ivalue_importer.cpp ------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "ivalue_importer.h"
|
||||||
|
#include "graph_importer.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mlir_utils.h"
|
||||||
|
|
||||||
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
|
#include "mlir-c/Diagnostics.h"
|
||||||
|
#include "npcomp-c/Types.h"
|
||||||
|
|
||||||
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Helper class for holding state during recursive IValue import.
|
||||||
|
///
|
||||||
|
/// The intended usage pattern of this class is to construct it then call
|
||||||
|
/// `importIValue` exactly once. Calling `importIValue` more than once
|
||||||
|
/// is likely to produce unexpected results since the same in-memory IValue
|
||||||
|
/// can be reimported more than once. That is, subsequent calls to
|
||||||
|
/// `importIValue` will not properly unify IValue's with already-imported
|
||||||
|
/// IValue's.
|
||||||
|
///
|
||||||
|
/// TODO: Support unifying repeated IValue's.
|
||||||
|
/// This already is an issue when importing a single IValue through the current
|
||||||
|
/// API, because the same underlying Tensor (or List/Dict) can be referenced by
|
||||||
|
/// multiple properties of a module. There is an extra complication with tensors
|
||||||
|
/// because they can alias each other in fairly arbitrary ways, which we will
|
||||||
|
/// need to model with slice ops.
|
||||||
|
class IValueImporter {
|
||||||
|
public:
|
||||||
|
IValueImporter(MlirBlock importBlock, MlirContext context)
|
||||||
|
: importBlock(importBlock), context(context), typeMapper(context) {}
|
||||||
|
|
||||||
|
MlirValue importIValue(c10::IValue value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
MlirValue importModule(torch::jit::Module jitModule);
|
||||||
|
void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody);
|
||||||
|
|
||||||
|
MlirBlock importBlock;
|
||||||
|
MlirContext context;
|
||||||
|
TypeMapper typeMapper;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
|
// TODO: Can we do better?
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
|
MlirOperation nnModule =
|
||||||
|
createMlirOperation("torch.nn_module", loc,
|
||||||
|
npcompNnModuleTypeGet(context), mlirRegionCreate());
|
||||||
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
|
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||||
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
|
|
||||||
|
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
|
||||||
|
const std::vector<c10::ClassAttribute> &classAttributes =
|
||||||
|
currentModule.type()->getAttributes();
|
||||||
|
assert(slots.size() == classAttributes.size() &&
|
||||||
|
"mismatch between object and type!");
|
||||||
|
for (int i = 0, e = slots.size(); i < e; i++) {
|
||||||
|
const c10::ClassAttribute &classAttribute = classAttributes[i];
|
||||||
|
MlirValue slotValue = importIValue(slots[i]);
|
||||||
|
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
||||||
|
createMlirOperationAtEnd(
|
||||||
|
nnModuleBody, "torch.attr", loc, slotValue,
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"name", mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(classAttribute.getName()))));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (torch::jit::Function *function : currentModule.type()->methods()) {
|
||||||
|
importMethod(function, nnModuleBody);
|
||||||
|
}
|
||||||
|
|
||||||
|
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
||||||
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
|
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
||||||
|
return mlirOperationGetResult(nnModule, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
|
||||||
|
// TODO: Can we do better?
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
|
if (ivalue.isBool()) {
|
||||||
|
MlirType type = npcompBoolTypeGet(context);
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
importBlock, "basicpy.bool_constant", loc, type,
|
||||||
|
toMlirNamedAttribute("value",
|
||||||
|
mlirBoolAttrGet(context, ivalue.toBool())));
|
||||||
|
return mlirOperationGetResult(operation, 0);
|
||||||
|
}
|
||||||
|
if (ivalue.isDouble()) {
|
||||||
|
MlirType type = mlirF64TypeGet(context);
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
importBlock, "basicpy.numeric_constant", loc, type,
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"value", mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
|
||||||
|
return mlirOperationGetResult(operation, 0);
|
||||||
|
}
|
||||||
|
if (ivalue.isInt()) {
|
||||||
|
MlirType type = mlirIntegerTypeGet(context, 64);
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
importBlock, "basicpy.numeric_constant", loc, type,
|
||||||
|
toMlirNamedAttribute("value",
|
||||||
|
mlirIntegerAttrGet(type, ivalue.toInt())));
|
||||||
|
return mlirOperationGetResult(operation, 0);
|
||||||
|
}
|
||||||
|
if (ivalue.isTensor()) {
|
||||||
|
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||||
|
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
|
||||||
|
MlirOperation constant = createMlirOperationAtEnd(
|
||||||
|
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
|
||||||
|
toMlirNamedAttribute("value", denseElements));
|
||||||
|
MlirOperation ndarray = createMlirOperationAtEnd(
|
||||||
|
importBlock, "numpy.create_array_from_tensor", loc,
|
||||||
|
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
|
||||||
|
mlirOperationGetResult(constant, 0));
|
||||||
|
return mlirOperationGetResult(ndarray, 0);
|
||||||
|
}
|
||||||
|
if (ivalue.isModule()) {
|
||||||
|
return importModule(ivalue.toModule());
|
||||||
|
}
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "Unsupported ivalue: " << ivalue;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void IValueImporter::importMethod(torch::jit::Function *function,
|
||||||
|
MlirBlock nnModuleBody) {
|
||||||
|
// TODO: Can we do better?
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
|
FuncBuilder::Inserter inserter = [&](MlirOperation func) {
|
||||||
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
|
importBlock, mlirBlockGetTerminator(importBlock), func);
|
||||||
|
// TODO: This should probably be a flag in MlirMappingOptions.
|
||||||
|
mlirOperationSetAttributeByName(
|
||||||
|
func, toMlirStringRef("sym_visibility"),
|
||||||
|
mlirStringAttrGet(context, toMlirStringRef("private")));
|
||||||
|
};
|
||||||
|
// We make an effort for the func op's symbol name to be useful for debugging,
|
||||||
|
// but still clearly non-load-bearing.
|
||||||
|
std::string symName =
|
||||||
|
"__npcomp_priv_fn." + function->qualname().qualifiedName();
|
||||||
|
GraphImporter::MlirMappingOptions mappingOptions{context, symName, symName,
|
||||||
|
typeMapper, inserter};
|
||||||
|
GraphImporter importer(function->graph(), mappingOptions);
|
||||||
|
importer.initialize();
|
||||||
|
importer.importGenericFunc();
|
||||||
|
createMlirOperationAtEnd(
|
||||||
|
nnModuleBody, "torch.method", loc,
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"name",
|
||||||
|
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
|
||||||
|
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
|
||||||
|
context, toMlirStringRef(symName))));
|
||||||
|
}
|
||||||
|
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
||||||
|
MlirContext context) {
|
||||||
|
// When debugging module importing, it can be useful to dump as so:
|
||||||
|
// if (ivalue.isModule())
|
||||||
|
// ivalue.toModule().dump(true, true, true);
|
||||||
|
IValueImporter importer(block, context);
|
||||||
|
importer.importIValue(ivalue);
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===- ivalue_importer.h ----------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H
|
||||||
|
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "../pybind.h"
|
||||||
|
#include "func_builder.h"
|
||||||
|
|
||||||
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/api/module.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
/// Main entry-point for importing torch IValue's .
|
||||||
|
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
|
||||||
|
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context);
|
||||||
|
|
||||||
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H
|
|
@ -31,6 +31,52 @@ inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
|
||||||
return mlirNamedAttributeGet(ident, attr);
|
return mlirNamedAttributeGet(ident, attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void addToMlirOperationState(MlirOperationState &state,
|
||||||
|
MlirNamedAttribute namedAttr) {
|
||||||
|
mlirOperationStateAddAttributes(&state, 1, &namedAttr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void addToMlirOperationState(MlirOperationState &state,
|
||||||
|
MlirRegion region) {
|
||||||
|
mlirOperationStateAddOwnedRegions(&state, 1, ®ion);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void addToMlirOperationState(MlirOperationState &state,
|
||||||
|
MlirValue value) {
|
||||||
|
mlirOperationStateAddOperands(&state, 1, &value);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void addToMlirOperationState(MlirOperationState &state,
|
||||||
|
MlirType resultType) {
|
||||||
|
mlirOperationStateAddResults(&state, 1, &resultType);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... Ts>
|
||||||
|
void addToMlirOperationState(MlirOperationState &state, T &&t, Ts &&...ts) {
|
||||||
|
addToMlirOperationState(state, t);
|
||||||
|
addToMlirOperationState(state, std::forward<Ts>(ts)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void addToMlirOperationState(MlirOperationState &state) {}
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
MlirOperation createMlirOperation(std::string name, MlirLocation loc,
|
||||||
|
Ts &&...ts) {
|
||||||
|
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
|
||||||
|
addToMlirOperationState(state, std::forward<Ts>(ts)...);
|
||||||
|
return mlirOperationCreate(&state);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
|
||||||
|
MlirLocation loc, Ts &&...ts) {
|
||||||
|
MlirOperation operation =
|
||||||
|
createMlirOperation(name, loc, std::forward<Ts>(ts)...);
|
||||||
|
mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
|
||||||
|
operation);
|
||||||
|
return operation;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H
|
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
#include "graph_importer.h"
|
#include "graph_importer.h"
|
||||||
|
#include "ivalue_importer.h"
|
||||||
|
|
||||||
#include "mlir-c/Bindings/Python/Interop.h"
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
|
@ -109,6 +110,11 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ModuleBuilder::importModule(torch::jit::Module jitModule) {
|
||||||
|
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
|
||||||
|
mlirModuleGetContext(module));
|
||||||
|
}
|
||||||
|
|
||||||
FuncBuilder::Inserter ModuleBuilder::createInserter() {
|
FuncBuilder::Inserter ModuleBuilder::createInserter() {
|
||||||
MlirBlock block = getBodyBlock();
|
MlirBlock block = getBodyBlock();
|
||||||
MlirOperation terminator = this->terminator;
|
MlirOperation terminator = this->terminator;
|
||||||
|
@ -129,5 +135,6 @@ void ModuleBuilder::bind(py::module &m) {
|
||||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||||
.def("capture_function", &ModuleBuilder::startCaptureFunction,
|
.def("capture_function", &ModuleBuilder::startCaptureFunction,
|
||||||
py::keep_alive<0, 1>())
|
py::keep_alive<0, 1>())
|
||||||
.def("import_function", &ModuleBuilder::importFunction);
|
.def("import_function", &ModuleBuilder::importFunction)
|
||||||
|
.def("import_module", &ModuleBuilder::importModule);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include <ATen/Tensor.h>
|
#include <ATen/Tensor.h>
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
@ -43,6 +44,9 @@ public:
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr
|
||||||
importFunction(torch::jit::StrongFunctionPtr function);
|
importFunction(torch::jit::StrongFunctionPtr function);
|
||||||
|
|
||||||
|
// Imports a torch::jit::Module into the current module.
|
||||||
|
void importModule(torch::jit::Module jitModule);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncBuilder::Inserter createInserter();
|
FuncBuilder::Inserter createInserter();
|
||||||
MlirBlock getBodyBlock();
|
MlirBlock getBodyBlock();
|
||||||
|
|
|
@ -0,0 +1,212 @@
|
||||||
|
//===- ivalue_importer.cpp ------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "graph_importer.h"
|
||||||
|
#include "ivalue_importer.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mlir_utils.h"
|
||||||
|
|
||||||
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
|
#include "mlir-c/Diagnostics.h"
|
||||||
|
#include "npcomp-c/Types.h"
|
||||||
|
|
||||||
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
|
||||||
|
auto type = rawMapFromTorchScalarType(scalarType);
|
||||||
|
if (mlirTypeIsNull(type)) {
|
||||||
|
std::stringstream message;
|
||||||
|
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||||
|
throw std::invalid_argument(message.str());
|
||||||
|
}
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
|
||||||
|
c10::ScalarType scalarType) {
|
||||||
|
auto type = rawMapFromTorchScalarType(scalarType);
|
||||||
|
if (mlirTypeIsNull(type)) {
|
||||||
|
std::stringstream message;
|
||||||
|
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||||
|
mlirEmitError(loc, message.str().c_str());
|
||||||
|
}
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
||||||
|
using c10::ScalarType;
|
||||||
|
switch (scalarType) {
|
||||||
|
case ScalarType::Byte:
|
||||||
|
// TODO: convert to mlirIntegerTypeUnsignedGet once supported.
|
||||||
|
return mlirIntegerTypeGet(context, 8);
|
||||||
|
case ScalarType::Char:
|
||||||
|
return mlirIntegerTypeGet(context, 8);
|
||||||
|
case ScalarType::Short:
|
||||||
|
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||||
|
return mlirIntegerTypeGet(context, 16);
|
||||||
|
case ScalarType::Int:
|
||||||
|
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||||
|
return mlirIntegerTypeGet(context, 32);
|
||||||
|
case ScalarType::Long:
|
||||||
|
// TODO: convert to mlirIntegerTypeSignedGet once supported.
|
||||||
|
return mlirIntegerTypeGet(context, 64);
|
||||||
|
case ScalarType::Bool:
|
||||||
|
return npcompBoolTypeGet(context);
|
||||||
|
case ScalarType::Double:
|
||||||
|
return mlirF64TypeGet(context);
|
||||||
|
case ScalarType::Float:
|
||||||
|
return mlirF32TypeGet(context);
|
||||||
|
case ScalarType::BFloat16:
|
||||||
|
return mlirBF16TypeGet(context);
|
||||||
|
case ScalarType::Half:
|
||||||
|
return mlirF16TypeGet(context);
|
||||||
|
default: {
|
||||||
|
return {nullptr};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
|
const c10::TypePtr &torchType) {
|
||||||
|
using c10::TypeKind;
|
||||||
|
auto kind = torchType->kind();
|
||||||
|
switch (kind) {
|
||||||
|
case TypeKind::TensorType: {
|
||||||
|
auto tensorType = torchType->cast<c10::TensorType>();
|
||||||
|
// Element type.
|
||||||
|
MlirType elementType;
|
||||||
|
if (tensorType->scalarType()) {
|
||||||
|
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
|
||||||
|
if (mlirTypeIsNull(elementType))
|
||||||
|
return {nullptr};
|
||||||
|
} else {
|
||||||
|
elementType = npcompAnyDtypeTypeGet(context);
|
||||||
|
}
|
||||||
|
// Sizes.
|
||||||
|
auto &sizes = tensorType->symbolic_sizes();
|
||||||
|
if (!sizes.rank()) {
|
||||||
|
// Unranked.
|
||||||
|
return npcompNdArrayTypeGetUnranked(elementType);
|
||||||
|
}
|
||||||
|
// Ranked with possibly dynamic dims.
|
||||||
|
auto &symbolicShape = tensorType->symbolic_sizes();
|
||||||
|
std::vector<int64_t> dims;
|
||||||
|
dims.resize(*sizes.rank());
|
||||||
|
for (size_t i = 0; i < dims.size(); ++i) {
|
||||||
|
auto shapeSymbol = symbolicShape[i];
|
||||||
|
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||||
|
}
|
||||||
|
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
||||||
|
}
|
||||||
|
case TypeKind::ClassType: {
|
||||||
|
return npcompNnModuleTypeGet(context);
|
||||||
|
}
|
||||||
|
case TypeKind::FloatType: {
|
||||||
|
return mlirF64TypeGet(context);
|
||||||
|
}
|
||||||
|
case TypeKind::IntType: {
|
||||||
|
return mlirIntegerTypeGet(context, 64);
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
std::stringstream message;
|
||||||
|
message << "unable to map Torch type " << *torchType << " to MLIR type";
|
||||||
|
mlirEmitError(loc, message.str().c_str());
|
||||||
|
return {nullptr};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
||||||
|
if (!tensor.defined()) {
|
||||||
|
// Undefined tensors are equivalent to None.
|
||||||
|
// This may need to be re-evaluated at some point.
|
||||||
|
return npcompNoneTypeGet(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
||||||
|
// TODO: Decide when it is necessary to take strides into account. Right now,
|
||||||
|
// just erase them and let the compiler decide.
|
||||||
|
|
||||||
|
auto sizes = tensor.sizes();
|
||||||
|
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
|
MlirLocation loc) {
|
||||||
|
MlirContext context = mlirLocationGetContext(loc);
|
||||||
|
TypeMapper typeMapper(context);
|
||||||
|
using at::ScalarType;
|
||||||
|
|
||||||
|
auto throwUnsupportedTensorError = [&]() {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "Unsupported import tensor type: " << tensor;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
|
||||||
|
if (!tensor.is_contiguous())
|
||||||
|
tensor = tensor.contiguous();
|
||||||
|
|
||||||
|
// The flat number of bytes throws an exception for tensors that are not
|
||||||
|
// dense and accessible as such.
|
||||||
|
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
|
||||||
|
c10::Layout::Strided);
|
||||||
|
|
||||||
|
// Construct the ShapedType.
|
||||||
|
MlirType elementType;
|
||||||
|
if (tensor.scalar_type() == ScalarType::Bool) {
|
||||||
|
// Bool is a special case. When used as an element type, it must be i1.
|
||||||
|
// The generalized (non-Tensor) conversion, assumes that Bool is the
|
||||||
|
// Basicpy bool type.
|
||||||
|
elementType = mlirIntegerTypeGet(context, 1);
|
||||||
|
} else {
|
||||||
|
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
|
||||||
|
}
|
||||||
|
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||||
|
MlirType shapedType = mlirRankedTensorTypeGetChecked(
|
||||||
|
shape.size(), shape.data(), elementType, loc);
|
||||||
|
if (mlirTypeIsNull(shapedType)) {
|
||||||
|
throwUnsupportedTensorError();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import DenseElementsAttr data.
|
||||||
|
// TODO: Support bool tensors.
|
||||||
|
// TODO: More import formats in C-API.
|
||||||
|
auto numElements = tensor.numel();
|
||||||
|
auto tensorData = tensor.data_ptr();
|
||||||
|
switch (tensor.scalar_type()) {
|
||||||
|
case ScalarType::Int:
|
||||||
|
return mlirDenseElementsAttrInt32Get(
|
||||||
|
shapedType, numElements, static_cast<const int32_t *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Long:
|
||||||
|
return mlirDenseElementsAttrInt64Get(
|
||||||
|
shapedType, numElements, static_cast<const int64_t *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Float:
|
||||||
|
return mlirDenseElementsAttrFloatGet(
|
||||||
|
shapedType, numElements, static_cast<const float *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Double:
|
||||||
|
return mlirDenseElementsAttrDoubleGet(
|
||||||
|
shapedType, numElements, static_cast<const double *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Bool:
|
||||||
|
// TODO: Add a test specifically for bool and ensure consistency between
|
||||||
|
// storage format and load format
|
||||||
|
// (https://github.com/llvm/mlir-npcomp/issues/100).
|
||||||
|
return mlirDenseElementsAttrBoolGet(shapedType, numElements,
|
||||||
|
static_cast<const int *>(tensorData));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throwUnsupportedTensorError();
|
||||||
|
}
|
||||||
|
return {nullptr}; // Unreachable.
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
//===- torch_to_mlir_utils.h ------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||||
|
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "../pybind.h"
|
||||||
|
|
||||||
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/api/module.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
/// Maps various runtime types to MlirType.
|
||||||
|
class TypeMapper {
|
||||||
|
public:
|
||||||
|
TypeMapper(MlirContext context) : context(context) {}
|
||||||
|
|
||||||
|
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||||
|
/// Throws std::invalid_argument on failure.
|
||||||
|
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
|
||||||
|
|
||||||
|
/// Gets a corresponding MlirType for the forward component of a tensor.
|
||||||
|
/// Throws std::invalid_argument on failure.
|
||||||
|
MlirType forwardTensorToType(at::Tensor tensor);
|
||||||
|
|
||||||
|
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||||
|
/// Torch ScalarType is used to represent the possible element types of Torch
|
||||||
|
/// tensors, which is different from the set of types used to represent
|
||||||
|
/// Python numeric scalar values (which are always either f64 or i64).
|
||||||
|
/// Returns a null type on failure and emits a diagnostic.
|
||||||
|
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
|
||||||
|
|
||||||
|
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
||||||
|
/// on failure and emits a diagnostic.
|
||||||
|
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Maps from a scalar type and returns null if no match (no other error
|
||||||
|
/// reporting).
|
||||||
|
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
|
||||||
|
MlirContext context;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||||
|
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
|
MlirLocation loc);
|
||||||
|
|
||||||
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
|
@ -2,26 +2,21 @@
|
||||||
# This file is licensed under a pytorch-style license
|
# This file is licensed under a pytorch-style license
|
||||||
# See frontends/pytorch/LICENSE for license information.
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
|
||||||
# RUN: %PYTHON %s
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
class ExampleClass:
|
|
||||||
def __init__(self, x):
|
|
||||||
self.x = x
|
|
||||||
|
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# For now, TorchScript classes are wholly unsupported, so use it to test
|
# To test errors, use a type that we don't support yet.
|
||||||
# type conversion errors.
|
|
||||||
try:
|
try:
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def import_class(c: ExampleClass):
|
def import_class(x: typing.Any):
|
||||||
return c.x
|
return x
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
||||||
assert str(e) == "could not convert function input type"
|
assert str(e) == "could not convert function input type"
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class Submodule(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.s = Submodule()
|
||||||
|
def forward(self, x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
|
||||||
|
# However, we do make an attempt to ensure that the names are debuggable.
|
||||||
|
#
|
||||||
|
# The names have the following structure:
|
||||||
|
# - `__npcomp_priv_fn`: marker that this symbol name is private to npcomp
|
||||||
|
# - `__torch__.Submodule.forward`: the name that TorchScript gives the function
|
||||||
|
# - For those curious, the `__torch__` would be the Python module name, but in
|
||||||
|
# the case that the name is `__main__` Torch replaces it with `__torch__` to
|
||||||
|
# avoid collisions.
|
||||||
|
|
||||||
|
# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward
|
||||||
|
# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward
|
||||||
|
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -0,0 +1,26 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
def forward(self, x, y):
|
||||||
|
# CHECK-LABEL: torch.nn_module
|
||||||
|
# CHECK: loc("{{.*}}methods-locations.py":[[@LINE+1]]
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print(enable_debug_info=True)
|
|
@ -0,0 +1,39 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
def forward(self, x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
|
||||||
|
|
||||||
|
# CHECK-LABEL: func private
|
||||||
|
# CHECK-SAME: @[[SYMNAME:.*]](
|
||||||
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
|
||||||
|
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||||
|
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||||
|
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
|
||||||
|
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
|
||||||
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.method "forward", @[[SYMNAME]]
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -0,0 +1,34 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.i = 3
|
||||||
|
self.f = 42.5
|
||||||
|
|
||||||
|
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||||
|
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
|
||||||
|
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
|
# CHECK: %[[MODULE:.*]] = torch.nn_module {
|
||||||
|
# Note: for some reason, Torch always adds a "training" property to all modules.
|
||||||
|
# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType
|
||||||
|
# CHECK: torch.attr "i", %[[N3]] : i64
|
||||||
|
# CHECK: torch.attr "f", %[[N42]] : f64
|
||||||
|
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -0,0 +1,52 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class Submodule(torch.nn.Module):
|
||||||
|
def __init__(self, n):
|
||||||
|
super().__init__()
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.s0 = Submodule(0)
|
||||||
|
self.s1 = Submodule(1)
|
||||||
|
|
||||||
|
# CHECK: %[[T:.*]] = basicpy.bool_constant true
|
||||||
|
|
||||||
|
# CHECK: %[[T1:.*]] = basicpy.bool_constant true
|
||||||
|
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
|
||||||
|
# CHECK: %[[S0:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.attr "training", %[[T1]] : !basicpy.BoolType
|
||||||
|
# CHECK: torch.attr "n", %[[N0]] : i64
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
# CHECK: %[[T2:.*]] = basicpy.bool_constant true
|
||||||
|
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||||
|
# CHECK: %[[S1:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.attr "training", %[[T2]] : !basicpy.BoolType
|
||||||
|
# CHECK: torch.attr "n", %[[N1]] : i64
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
||||||
|
# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module
|
||||||
|
# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -0,0 +1,35 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# TODO: Test (and make work) tensors that alias each other.
|
||||||
|
self.t = torch.ones(1)
|
||||||
|
self.p = torch.nn.Parameter(torch.arange(3.0))
|
||||||
|
|
||||||
|
# CHECK: %[[CP:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>
|
||||||
|
# CHECK: %[[P:.*]] = numpy.create_array_from_tensor %[[CP]] : (tensor<3xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
|
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
|
# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -107,6 +107,16 @@ int npcompTypeIsATuple(MlirType t);
|
||||||
/** Gets the generic Python "tuple" type. */
|
/** Gets the generic Python "tuple" type. */
|
||||||
MlirType npcompTupleTypeGet(MlirContext context);
|
MlirType npcompTupleTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.nn.Module type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a torch.nn.Module type */
|
||||||
|
int npcompTypeIsANnModule(MlirType t);
|
||||||
|
|
||||||
|
/** Gets the singleton torch.nn.Module type. */
|
||||||
|
MlirType npcompNnModuleTypeGet(MlirContext context);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
include "npcomp/Dialect/ATen/IR/ATenDialect.td"
|
include "npcomp/Dialect/ATen/IR/ATenDialect.td"
|
||||||
include "npcomp/Dialect/ATen/IR/ATenOpInterface.td"
|
include "npcomp/Dialect/ATen/IR/ATenOpInterface.td"
|
||||||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||||
include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||||
|
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,11 @@ mlir_tablegen(TorchDialect.h.inc -gen-dialect-decls -dialect=torch)
|
||||||
add_public_tablegen_target(MLIRTorchOpsIncGen)
|
add_public_tablegen_target(MLIRTorchOpsIncGen)
|
||||||
add_dependencies(mlir-headers MLIRTorchOpsIncGen)
|
add_dependencies(mlir-headers MLIRTorchOpsIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TorchTypes.td)
|
||||||
|
mlir_tablegen(TorchTypes.h.inc -gen-typedef-decls)
|
||||||
|
mlir_tablegen(TorchTypes.cpp.inc -gen-typedef-defs)
|
||||||
|
add_public_tablegen_target(MLIRTorchTypesIncGen)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS OpInterfaces.td)
|
set(LLVM_TARGET_DEFINITIONS OpInterfaces.td)
|
||||||
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
|
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
|
||||||
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
|
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||||
|
|
|
@ -26,7 +26,7 @@ def Torch_Dialect : Dialect {
|
||||||
- Transitions between mutable and immutable tensors.
|
- Transitions between mutable and immutable tensors.
|
||||||
- Gradient associations and management.
|
- Gradient associations and management.
|
||||||
- Custom ops.
|
- Custom ops.
|
||||||
- Types specific to PyTorch.
|
- Types specific to PyTorch such as torch.nn.Module structures
|
||||||
- Module level constructs like quantization parameters, globals, etc.
|
- Module level constructs like quantization parameters, globals, etc.
|
||||||
|
|
||||||
Where possible, this dialect composes with types and ops from the `Numpy`
|
Where possible, this dialect composes with types and ops from the `Numpy`
|
||||||
|
@ -40,92 +40,4 @@ def Torch_Dialect : Dialect {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Type predicates
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Torch has a fairly advanced and featureful Tensor type, and some of the
|
|
||||||
// semantics are important to preserve in a compilation context. In the future,
|
|
||||||
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
|
|
||||||
// and interop are well served by existing tensor-like types, which are
|
|
||||||
// specifically permitted. Typically, on import, constraints are fairly loose
|
|
||||||
// and based on how the program is captured. Settling on and refining to
|
|
||||||
// specific types is done as part of lowering.
|
|
||||||
//
|
|
||||||
// While lowering it is useful to be able to distinguish between mutable and
|
|
||||||
// immutable tensors:
|
|
||||||
// - Mutable tensors can alias.
|
|
||||||
// - Mutable tensors can be a view over another mutable tensor.
|
|
||||||
// - Mutable tensors act as if reference counted and exist for the lifetime
|
|
||||||
// of any reference or derived view.
|
|
||||||
// Conversely, immutable tensors:
|
|
||||||
// - Are normal SSA values representing the contents of the tensor.
|
|
||||||
// - Cannot alias.
|
|
||||||
// - Cannot be a view of any mutable value.
|
|
||||||
// - Have undefined lifetimes.
|
|
||||||
//
|
|
||||||
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
|
|
||||||
// however, when lowering to specific ops, further constraints are introduced,
|
|
||||||
// necessitating copies, loads, and stores to be inserted to bridge worlds.
|
|
||||||
def AnyTorchImmutableTensor : AnyTypeOf<[
|
|
||||||
// Normal MLIR immutable tensors.
|
|
||||||
AnyTensor,
|
|
||||||
], "allowable torch immutable tensor">;
|
|
||||||
|
|
||||||
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
|
|
||||||
AnyTorchImmutableTensor,
|
|
||||||
Basicpy_NoneType,
|
|
||||||
], "allowable torch immutable tensor (or None)">;
|
|
||||||
|
|
||||||
def AnyTorchMutableTensor : AnyTypeOf<[
|
|
||||||
// "Numpy-style" mutable NDArray. While not offering the full generality
|
|
||||||
// of a Torch tensor, it models the same access patterns and implies the
|
|
||||||
// same aliasing as Torch tensors.
|
|
||||||
Numpy_NdArrayType,
|
|
||||||
], "allowable torch mutable tensor">;
|
|
||||||
|
|
||||||
def AnyTorchTensorType : AnyTypeOf<[
|
|
||||||
AnyTorchImmutableTensor,
|
|
||||||
AnyTorchMutableTensor,
|
|
||||||
], "Any tensor type legal to pass to a Torch kernel">;
|
|
||||||
|
|
||||||
def AnyTorchScalarType : AnyTypeOf<[
|
|
||||||
AnySignedInteger,
|
|
||||||
AnyFloat,
|
|
||||||
Basicpy_BoolType,
|
|
||||||
Basicpy_StrType,
|
|
||||||
Basicpy_NoneType,
|
|
||||||
// Allow signless integers for ease of conversions. In general, this
|
|
||||||
// dialect uses signed integers.
|
|
||||||
AnySignlessInteger,
|
|
||||||
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
|
||||||
|
|
||||||
def AnyTorchBoolType : AnyTypeOf<[
|
|
||||||
I1,
|
|
||||||
Basicpy_BoolType,
|
|
||||||
], "Any permissible bool type">;
|
|
||||||
|
|
||||||
def AnyTorchBoolListType : AnyTypeOf<[
|
|
||||||
Basicpy_ListType,
|
|
||||||
// TODO: Support typed list when available.
|
|
||||||
], "Any bool list type (bool[])">;
|
|
||||||
|
|
||||||
def AnyTorchIntType : AnyTypeOf<[
|
|
||||||
AnySignedInteger,
|
|
||||||
AnySignlessInteger,
|
|
||||||
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
|
|
||||||
|
|
||||||
def AnyTorchIntListType : AnyTypeOf<[
|
|
||||||
Basicpy_ListType,
|
|
||||||
// TODO: Support typed list when available.
|
|
||||||
], "Any int list type (int[])">;
|
|
||||||
|
|
||||||
def AnyTorchType : AnyTypeOf<[
|
|
||||||
AnyTorchBoolType,
|
|
||||||
AnyTorchScalarType,
|
|
||||||
AnyTorchTensorType,
|
|
||||||
Basicpy_ListType,
|
|
||||||
Basicpy_NoneType,
|
|
||||||
], "Any type that is legal to pass to a Torch kernel">;
|
|
||||||
|
|
||||||
#endif // TORCH_BASE
|
#endif // TORCH_BASE
|
||||||
|
|
|
@ -12,7 +12,9 @@
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
||||||
|
|
|
@ -9,8 +9,9 @@
|
||||||
#ifndef TORCH_OPS
|
#ifndef TORCH_OPS
|
||||||
#define TORCH_OPS
|
#define TORCH_OPS
|
||||||
|
|
||||||
include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Torch_Dialect, mnemonic, traits> {
|
: Op<Torch_Dialect, mnemonic, traits> {
|
||||||
|
@ -47,4 +48,87 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TorchScript modeling ops.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
|
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
||||||
|
let summary = "Constructs a torch.nn.Module";
|
||||||
|
let description = [{
|
||||||
|
This op is used to represent a torch.nn.Module when importing a
|
||||||
|
graph of Python objects.
|
||||||
|
|
||||||
|
This op returns a new torch.nn.Module as an SSA value, with a set of
|
||||||
|
declaratively specified properties.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%2 = torch.nn_module {
|
||||||
|
torch.attr "b", %bool_true : !basicpy.BoolType
|
||||||
|
torch.attr "i", %num3_i64 : i64
|
||||||
|
torch.attr "f", %num : f64
|
||||||
|
torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.attr "submodule", %1 : !torch.nn.Module
|
||||||
|
torch.method "method", @f
|
||||||
|
}
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs Torch_NnModuleType:$result);
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
let verifier = "return ::verify(*this);";
|
||||||
|
|
||||||
|
let assemblyFormat = "$region attr-dict";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||||
|
let summary = "Implicit terminator for torch.nn_module";
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||||
|
let summary = "Define an attribute of a torch.nn.Module";
|
||||||
|
let description = [{
|
||||||
|
This op declaratively specifies that the parent torch.nn_module has an
|
||||||
|
attribute `name` with value `value`, which is allowed to be an arbitrary
|
||||||
|
Torch-compatible SSA value, including other torch.nn.Module's.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$name `,` $value attr-dict `:` type($value)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_MethodOp : Torch_Op<"method", [
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">,
|
||||||
|
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Define a method of a torch.nn.Module";
|
||||||
|
let description = [{
|
||||||
|
This op declaratively specifies that the parent torch.nn_module has a
|
||||||
|
method `name` which calls `function`. `function` is an unbound function.
|
||||||
|
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
|
||||||
|
"self" object).
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins StrAttr:$name, FlatSymbolRefAttr:$function);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$name `,` $function attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCH_OPS
|
#endif // TORCH_OPS
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
|
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
|
#define GET_TYPEDEF_CLASSES
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
|
@ -0,0 +1,118 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_TYPES
|
||||||
|
#define TORCH_TYPES
|
||||||
|
|
||||||
|
include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type defs
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name> {
|
||||||
|
let mnemonic = typeMnemonic;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
||||||
|
let summary = "torch.nn.Module";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type predicates
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Torch has a fairly advanced and featureful Tensor type, and some of the
|
||||||
|
// semantics are important to preserve in a compilation context. In the future,
|
||||||
|
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
|
||||||
|
// and interop are well served by existing tensor-like types, which are
|
||||||
|
// specifically permitted. Typically, on import, constraints are fairly loose
|
||||||
|
// and based on how the program is captured. Settling on and refining to
|
||||||
|
// specific types is done as part of lowering.
|
||||||
|
//
|
||||||
|
// While lowering it is useful to be able to distinguish between mutable and
|
||||||
|
// immutable tensors:
|
||||||
|
// - Mutable tensors can alias.
|
||||||
|
// - Mutable tensors can be a view over another mutable tensor.
|
||||||
|
// - Mutable tensors act as if reference counted and exist for the lifetime
|
||||||
|
// of any reference or derived view.
|
||||||
|
// Conversely, immutable tensors:
|
||||||
|
// - Are normal SSA values representing the contents of the tensor.
|
||||||
|
// - Cannot alias.
|
||||||
|
// - Cannot be a view of any mutable value.
|
||||||
|
// - Have undefined lifetimes.
|
||||||
|
//
|
||||||
|
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
|
||||||
|
// however, when lowering to specific ops, further constraints are introduced,
|
||||||
|
// necessitating copies, loads, and stores to be inserted to bridge worlds.
|
||||||
|
def AnyTorchImmutableTensor : AnyTypeOf<[
|
||||||
|
// Normal MLIR immutable tensors.
|
||||||
|
AnyTensor,
|
||||||
|
], "allowable torch immutable tensor">;
|
||||||
|
|
||||||
|
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
|
||||||
|
AnyTorchImmutableTensor,
|
||||||
|
Basicpy_NoneType,
|
||||||
|
], "allowable torch immutable tensor (or None)">;
|
||||||
|
|
||||||
|
def AnyTorchMutableTensor : AnyTypeOf<[
|
||||||
|
// "Numpy-style" mutable NDArray. While not offering the full generality
|
||||||
|
// of a Torch tensor, it models the same access patterns and implies the
|
||||||
|
// same aliasing as Torch tensors.
|
||||||
|
Numpy_NdArrayType,
|
||||||
|
], "allowable torch mutable tensor">;
|
||||||
|
|
||||||
|
def AnyTorchTensorType : AnyTypeOf<[
|
||||||
|
AnyTorchImmutableTensor,
|
||||||
|
AnyTorchMutableTensor,
|
||||||
|
], "Any tensor type legal to pass to a Torch kernel">;
|
||||||
|
|
||||||
|
def AnyTorchScalarType : AnyTypeOf<[
|
||||||
|
AnySignedInteger,
|
||||||
|
AnyFloat,
|
||||||
|
Basicpy_BoolType,
|
||||||
|
Basicpy_StrType,
|
||||||
|
Basicpy_NoneType,
|
||||||
|
// Allow signless integers for ease of conversions. In general, this
|
||||||
|
// dialect uses signed integers.
|
||||||
|
AnySignlessInteger,
|
||||||
|
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
||||||
|
|
||||||
|
def AnyTorchBoolType : AnyTypeOf<[
|
||||||
|
I1,
|
||||||
|
Basicpy_BoolType,
|
||||||
|
], "Any permissible bool type">;
|
||||||
|
|
||||||
|
def AnyTorchBoolListType : AnyTypeOf<[
|
||||||
|
Basicpy_ListType,
|
||||||
|
// TODO: Support typed list when available.
|
||||||
|
], "Any bool list type (bool[])">;
|
||||||
|
|
||||||
|
def AnyTorchIntType : AnyTypeOf<[
|
||||||
|
AnySignedInteger,
|
||||||
|
AnySignlessInteger,
|
||||||
|
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
|
||||||
|
|
||||||
|
def AnyTorchIntListType : AnyTypeOf<[
|
||||||
|
Basicpy_ListType,
|
||||||
|
// TODO: Support typed list when available.
|
||||||
|
], "Any int list type (int[])">;
|
||||||
|
|
||||||
|
def AnyTorchType : AnyTypeOf<[
|
||||||
|
AnyTorchBoolType,
|
||||||
|
AnyTorchScalarType,
|
||||||
|
AnyTorchTensorType,
|
||||||
|
Basicpy_ListType,
|
||||||
|
Basicpy_NoneType,
|
||||||
|
Torch_NnModuleType,
|
||||||
|
], "Any type that is legal to pass to a Torch kernel">;
|
||||||
|
|
||||||
|
#endif // TORCH_TYPES
|
|
@ -16,4 +16,5 @@ add_npcomp_library(NPCOMPCAPI
|
||||||
NPCOMPInitAll
|
NPCOMPInitAll
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
|
NPCOMPTorchDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -133,3 +134,17 @@ int npcompTypeIsATuple(MlirType t) {
|
||||||
MlirType npcompTupleTypeGet(MlirContext context) {
|
MlirType npcompTupleTypeGet(MlirContext context) {
|
||||||
return wrap(Basicpy::TupleType::get(unwrap(context)));
|
return wrap(Basicpy::TupleType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.nn.Module type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a torch.nn.Module type */
|
||||||
|
int npcompTypeIsANnModule(MlirType t) {
|
||||||
|
return unwrap(t).isa<Torch::NnModuleType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the singleton torch.nn.Module type. */
|
||||||
|
MlirType npcompNnModuleTypeGet(MlirContext context) {
|
||||||
|
return wrap(Torch::NnModuleType::get(unwrap(context)));
|
||||||
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ add_npcomp_dialect_library(NPCOMPATenDialect
|
||||||
MLIRATenIncGen
|
MLIRATenIncGen
|
||||||
#MLIRATenEnumsIncGen
|
#MLIRATenEnumsIncGen
|
||||||
MLIRATenOpInterfacesIncGen
|
MLIRATenOpInterfacesIncGen
|
||||||
|
MLIRTorchOpInterfacesIncGen
|
||||||
#MLIRATenToStdIncGen
|
#MLIRATenToStdIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
|
|
@ -31,7 +31,10 @@ OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
void BoolConstantOp::getAsmResultNames(
|
void BoolConstantOp::getAsmResultNames(
|
||||||
function_ref<void(Value, StringRef)> setNameFn) {
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||||||
setNameFn(getResult(), "bool");
|
if (value())
|
||||||
|
setNameFn(getResult(), "bool_true");
|
||||||
|
else
|
||||||
|
setNameFn(getResult(), "bool_false");
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -8,6 +8,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRTorchOpsIncGen
|
MLIRTorchOpsIncGen
|
||||||
MLIRTorchOpInterfacesIncGen
|
MLIRTorchOpInterfacesIncGen
|
||||||
|
MLIRTorchTypesIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
|
@ -7,16 +7,47 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Tablegen Type Definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_TYPEDEF_CLASSES
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
|
|
||||||
void TorchDialect::initialize() {
|
void TorchDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
addTypes<
|
||||||
|
#define GET_TYPEDEF_LIST
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
StringRef keyword;
|
||||||
|
if (parser.parseKeyword(&keyword))
|
||||||
|
return Type();
|
||||||
|
if (Type type = generatedTypeParser(getContext(), parser, keyword))
|
||||||
|
return type;
|
||||||
|
|
||||||
|
parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
|
||||||
|
<< keyword << "'";
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
||||||
|
if (failed(generatedTypePrinter(type, printer)))
|
||||||
|
llvm_unreachable("unknown 'torch' type");
|
||||||
}
|
}
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
@ -16,9 +17,7 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Torch;
|
||||||
#define GET_OP_CLASSES
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
|
||||||
|
|
||||||
static SmallVector<StringRef, 4> strArrayAttrToVector(ArrayAttr array) {
|
static SmallVector<StringRef, 4> strArrayAttrToVector(ArrayAttr array) {
|
||||||
SmallVector<StringRef, 4> strings;
|
SmallVector<StringRef, 4> strings;
|
||||||
|
@ -29,12 +28,12 @@ static SmallVector<StringRef, 4> strArrayAttrToVector(ArrayAttr array) {
|
||||||
return strings;
|
return strings;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
//===----------------------------------------------------------------------===//
|
||||||
// KernelCall op
|
// KernelCallOp
|
||||||
// -----------------------------------------------------------------------------
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Torch::KernelMetadata Torch::KernelCallOp::getTorchKernelMetadata() {
|
KernelMetadata KernelCallOp::getTorchKernelMetadata() {
|
||||||
return Torch::KernelMetadata{
|
return KernelMetadata{
|
||||||
.kernelName = kernelName(),
|
.kernelName = kernelName(),
|
||||||
.isVararg = sigIsVararg(),
|
.isVararg = sigIsVararg(),
|
||||||
.isVarret = sigIsVarret(),
|
.isVarret = sigIsVarret(),
|
||||||
|
@ -42,3 +41,29 @@ Torch::KernelMetadata Torch::KernelCallOp::getTorchKernelMetadata() {
|
||||||
.returnTypes = strArrayAttrToVector(sigRetTypes()),
|
.returnTypes = strArrayAttrToVector(sigRetTypes()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MethodOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
|
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
|
||||||
|
if (!func)
|
||||||
|
return emitError() << "'" << function()
|
||||||
|
<< "' does not reference a valid function";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// NnModuleOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verify(NnModuleOp op) {
|
||||||
|
for (Operation &child : *op.getBody())
|
||||||
|
if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child))
|
||||||
|
return child.emitOpError() << "is not allowed inside `torch.nn_module`";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
|
|
|
@ -50,7 +50,7 @@ func @numeric_constant_complex_f32() -> complex<f32> {
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @bool_constant
|
// CHECK-LABEL: @bool_constant
|
||||||
func @bool_constant() -> !basicpy.BoolType {
|
func @bool_constant() -> !basicpy.BoolType {
|
||||||
// CHECK: %bool = basicpy.bool_constant true
|
// CHECK: %bool_true = basicpy.bool_constant true
|
||||||
%0 = basicpy.bool_constant true
|
%0 = basicpy.bool_constant true
|
||||||
return %0 : !basicpy.BoolType
|
return %0 : !basicpy.BoolType
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,5 +35,9 @@ func @numeric_constant() {
|
||||||
%2 = basicpy.numeric_constant 2.0 : f32
|
%2 = basicpy.numeric_constant 2.0 : f32
|
||||||
// CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
// CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
|
||||||
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
%3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
|
||||||
|
// CHECK: %bool_true = basicpy.bool_constant true
|
||||||
|
%4 = basicpy.bool_constant true
|
||||||
|
// CHECK: %bool_false = basicpy.bool_constant false
|
||||||
|
%5 = basicpy.bool_constant false
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.nn_module {
|
||||||
|
// expected-error @+1 {{'func' op is not allowed inside `torch.nn_module`}}
|
||||||
|
func @f()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.nn_module {
|
||||||
|
// expected-error @+1 {{'invalidSym' does not reference a valid function}}
|
||||||
|
torch.method "f", @invalidSym
|
||||||
|
}
|
|
@ -1,9 +1,29 @@
|
||||||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
|
func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: %0 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
|
// CHECK: torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
|
||||||
%1 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> (tensor<*xf32>) {
|
%1 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> (tensor<*xf32>) {
|
||||||
sigArgTypes = [], sigRetTypes = [], sigIsVararg = false, sigIsVarret = false, sigIsMutable = false
|
sigArgTypes = [], sigRetTypes = [], sigIsVararg = false, sigIsVarret = false, sigIsMutable = false
|
||||||
}
|
}
|
||||||
return %1 : tensor<*xf32>
|
return %1 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
%bool_true = basicpy.bool_constant true
|
||||||
|
%num3_i64 = basicpy.numeric_constant 3 : i64
|
||||||
|
%num = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
|
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
|
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
func @f(%arg0: !torch.nn.Module) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
%submodule = torch.nn_module {}
|
||||||
|
|
||||||
|
torch.nn_module {
|
||||||
|
torch.attr "b", %bool_true : !basicpy.BoolType
|
||||||
|
torch.attr "i", %num3_i64 : i64
|
||||||
|
torch.attr "f", %num : f64
|
||||||
|
torch.attr "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.attr "submodule", %submodule : !torch.nn.Module
|
||||||
|
torch.method "method", @f
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue