Add initial TorchScript module importer

It turns out that this was easiest to structure as a general IValue
importer, since torch modules are just one of the possible IValues.

We import the IValue object graph in a braindead fashion into basicpy
ops and a new `torch.nn_module` op that is used to model the
attributes/methods of a torch::jit::Module IValue. See `Torch/ops.mlir`
for an example, and also check out the .py import tests in
`frontends/pytorch/test/module_import`.

As part of this change, a few housekeeping tasks:
- extract some helpers from graph_importer.cpp
- more helpers around the C API
- misc touchups
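
For orientation, the new import flow in a nutshell (a minimal C++ sketch
assuming only the entry point added in this patch; `importExample` is an
illustrative name, not part of the change):

    #include "ivalue_importer.h"
    #include <torch/csrc/jit/api/module.h>

    // A torch::jit::Module is just an object IValue, so importing a module
    // reduces to recursively importing its root IValue at the end of the
    // MLIR module body; this mirrors ModuleBuilder::importModule below.
    void importExample(torch::jit::Module jitModule, MlirModule module) {
      torch_mlir::importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
                               mlirModuleGetContext(module));
    }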
Sean Silva 2021-01-27 16:35:44 -08:00
parent 0f6a65a1c5
commit 689b40c7a6
38 changed files with 1147 additions and 321 deletions


@ -16,7 +16,9 @@ add_library(NPCOMPTorchMLIRExt SHARED
builder/func_builder.cpp
builder/graph_importer.cpp
builder/module_builder.cpp
builder/ivalue_importer.cpp
builder/python_bindings.cpp
builder/torch_to_mlir_utils.cpp
init_python_bindings.cpp
)


@ -529,80 +529,15 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
}
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
using at::ScalarType;
auto throwUnsupportedTensorError = [&]() {
std::stringstream msg;
msg << "Unsupported import tensor type: " << tensor;
throw std::invalid_argument(msg.str());
};
// Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
if (!tensor.is_contiguous())
tensor = tensor.contiguous();
// The flat number of bytes throws an exception for tensors that are not
// dense and accessible as such.
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
c10::Layout::Strided);
// Construct the ShapedType.
auto loc = getCurrentLocation();
MlirType elementType;
MlirAttribute valueAttribute = converTensorToMlirElementsAttr(tensor, loc);
if (tensor.scalar_type() == ScalarType::Bool) {
// Bool is a special case. When used as an element type, it must be i1.
// The generalized (non-Tensor) conversion, assumes that Bool is the
// Basicpy bool type.
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
} else {
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
}
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType shapedType = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType, loc);
if (mlirTypeIsNull(shapedType)) {
throwUnsupportedTensorError();
}
// Import DenseElementsAttr data.
// TODO: Support bool tensors.
// TODO: More import formats in C-API.
MlirAttribute valueAttribute;
auto numElements = tensor.numel();
auto tensorData = tensor.data_ptr();
switch (tensor.scalar_type()) {
case ScalarType::Int:
valueAttribute = mlirDenseElementsAttrInt32Get(
shapedType, numElements, static_cast<const int32_t *>(tensorData));
break;
case ScalarType::Long:
valueAttribute = mlirDenseElementsAttrInt64Get(
shapedType, numElements, static_cast<const int64_t *>(tensorData));
break;
case ScalarType::Float:
valueAttribute = mlirDenseElementsAttrFloatGet(
shapedType, numElements, static_cast<const float *>(tensorData));
break;
case ScalarType::Double:
valueAttribute = mlirDenseElementsAttrDoubleGet(
shapedType, numElements, static_cast<const double *>(tensorData));
break;
case ScalarType::Bool:
// TODO: Add a test specifically for bool and ensure consistency between
// storage format and load format
// (https://github.com/llvm/mlir-npcomp/issues/100).
valueAttribute = mlirDenseElementsAttrBoolGet(
shapedType, numElements, static_cast<const int *>(tensorData));
break;
default:
throwUnsupportedTensorError();
}
MlirValue constTensorValue =
funcBuilder->getGeneralConstant(loc, valueAttribute);
// Create an array from the tensor constant via the
// numpy.create_array_from_tensor op.
MlirType constArrayType = npcompNdArrayTypeGetFromShaped(shapedType);
MlirType constArrayType =
npcompNdArrayTypeGetFromShaped(mlirAttributeGetType(valueAttribute));
MlirOperationState state = mlirOperationStateGet(
toMlirStringRef("numpy.create_array_from_tensor"), loc);
mlirOperationStateAddOperands(&state, 1, &constTensorValue);


@ -87,116 +87,6 @@ void KernelCallBuilder::addResultType(MlirType resultType) {
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
throw std::invalid_argument(message.str());
}
return type;
}
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
mlirEmitError(loc, message.str().c_str());
}
return type;
}
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
using c10::ScalarType;
switch (scalarType) {
case ScalarType::Byte:
// TODO: convert to mlirIntegerTypeUnsignedGet once supported.
return mlirIntegerTypeGet(context, 8);
case ScalarType::Char:
return mlirIntegerTypeGet(context, 8);
case ScalarType::Short:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 16);
case ScalarType::Int:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 32);
case ScalarType::Long:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 64);
case ScalarType::Bool:
return npcompBoolTypeGet(context);
case ScalarType::Double:
return mlirF64TypeGet(context);
case ScalarType::Float:
return mlirF32TypeGet(context);
case ScalarType::BFloat16:
return mlirBF16TypeGet(context);
case ScalarType::Half:
return mlirF16TypeGet(context);
default: {
return {nullptr};
}
}
}
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
using c10::TypeKind;
auto kind = torchType->kind();
switch (kind) {
case TypeKind::TensorType: {
auto tensorType = torchType->cast<c10::TensorType>();
// Element type.
MlirType elementType;
if (tensorType->scalarType()) {
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
if (mlirTypeIsNull(elementType))
return {nullptr};
} else {
elementType = npcompAnyDtypeTypeGet(context);
}
// Sizes.
auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) {
// Unranked.
return npcompNdArrayTypeGetUnranked(elementType);
}
// Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes();
std::vector<int64_t> dims;
dims.resize(*sizes.rank());
for (size_t i = 0; i < dims.size(); ++i) {
auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
}
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
}
default: {
std::stringstream message;
message << "unable to map Torch type " << torchType << " to MLIR type";
mlirEmitError(loc, message.str().c_str());
return {nullptr};
}
}
}
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
if (!tensor.defined()) {
// Undefined tensors are equivalent to None.
// This may need to be re-evaluated at some point.
return npcompNoneTypeGet(context);
}
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
// TODO: Decide when it is necessary to take strides into account. Right now,
// just erase them and let the compiler decide.
auto sizes = tensor.sizes();
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
}
std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
MlirLocation location, const std::string &name,


@ -9,6 +9,7 @@
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_FUNC_BUILDER_H
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/IR.h"
@ -45,34 +46,6 @@ private:
bool owned = true;
};
/// Maps various runtime types to MlirType.
class TypeMapper {
public:
TypeMapper(MlirContext context) : context(context) {}
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Throws std::invalid_argument on failure.
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
/// Gets a corresponding MlirType for the forward component of a tensor.
/// Throws std::invalid_argument on failure.
MlirType forwardTensorToType(at::Tensor tensor);
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Returns a null type on failure and emits a diagnostic.
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
private:
/// Maps from a scalar type and returns null if no match (no other error
/// reporting).
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
MlirContext context;
};
/// Wraps an MlirBlock under construction, primarily tracking the terminator
/// and supporting manipulation of it. The terminator may be null if it has
/// not yet been constructed.


@ -20,8 +20,10 @@
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Graph instances (and structures
/// surrounding them such as modules and methods).
/// Main entry-point for importing torch::jit::Graph instances.
///
/// This code doesn't handle importing of torch::jit::Module's. See
/// the IValue importer (ivalue_importer.h) for that.
///
/// In torch terminology, a Graph is a function. Later in the compiler, we may
/// specialize multiple versions of it.


@ -0,0 +1,177 @@
//===- ivalue_importer.cpp ------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "ivalue_importer.h"
#include "graph_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
using namespace torch_mlir;
namespace {
/// Helper class for holding state during recursive IValue import.
///
/// The intended usage pattern of this class is to construct it then call
/// `importIValue` exactly once. Calling `importIValue` more than once
/// is likely to produce unexpected results since the same in-memory IValue
/// can be reimported more than once. That is, subsequent calls to
/// `importIValue` will not properly unify IValue's with already-imported
/// IValue's.
///
/// TODO: Support unifying repeated IValue's.
/// This already is an issue when importing a single IValue through the current
/// API, because the same underlying Tensor (or List/Dict) can be referenced by
/// multiple properties of a module. There is an extra complication with tensors
/// because they can alias each other in fairly arbitrary ways, which we will
/// need to model with slice ops.
class IValueImporter {
public:
IValueImporter(MlirBlock importBlock, MlirContext context)
: importBlock(importBlock), context(context), typeMapper(context) {}
MlirValue importIValue(c10::IValue value);
private:
MlirValue importModule(torch::jit::Module jitModule);
void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody);
MlirBlock importBlock;
MlirContext context;
TypeMapper typeMapper;
};
} // namespace
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
MlirOperation nnModule =
createMlirOperation("torch.nn_module", loc,
npcompNnModuleTypeGet(context), mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
const std::vector<c10::ClassAttribute> &classAttributes =
currentModule.type()->getAttributes();
assert(slots.size() == classAttributes.size() &&
"mismatch between object and type!");
for (int i = 0, e = slots.size(); i < e; i++) {
const c10::ClassAttribute &classAttribute = classAttributes[i];
MlirValue slotValue = importIValue(slots[i]);
// TODO: Is it necessary to track whether an attribute is a "parameter"?
createMlirOperationAtEnd(
nnModuleBody, "torch.attr", loc, slotValue,
toMlirNamedAttribute(
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))));
}
for (torch::jit::Function *function : currentModule.type()->methods()) {
importMethod(function, nnModuleBody);
}
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
return mlirOperationGetResult(nnModule, 0);
}
MlirValue IValueImporter::importIValue(c10::IValue ivalue) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
if (ivalue.isBool()) {
MlirType type = npcompBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bool_constant", loc, type,
toMlirNamedAttribute("value",
mlirBoolAttrGet(context, ivalue.toBool())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isDouble()) {
MlirType type = mlirF64TypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.numeric_constant", loc, type,
toMlirNamedAttribute(
"value", mlirFloatAttrDoubleGet(context, type, ivalue.toDouble())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isInt()) {
MlirType type = mlirIntegerTypeGet(context, 64);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.numeric_constant", loc, type,
toMlirNamedAttribute("value",
mlirIntegerAttrGet(type, ivalue.toInt())));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
MlirOperation constant = createMlirOperationAtEnd(
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirOperation ndarray = createMlirOperationAtEnd(
importBlock, "numpy.create_array_from_tensor", loc,
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
mlirOperationGetResult(constant, 0));
return mlirOperationGetResult(ndarray, 0);
}
if (ivalue.isModule()) {
return importModule(ivalue.toModule());
}
std::stringstream msg;
msg << "Unsupported ivalue: " << ivalue;
throw std::invalid_argument(msg.str());
}
void IValueImporter::importMethod(torch::jit::Function *function,
MlirBlock nnModuleBody) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
FuncBuilder::Inserter inserter = [&](MlirOperation func) {
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), func);
// TODO: This should probably be a flag in MlirMappingOptions.
mlirOperationSetAttributeByName(
func, toMlirStringRef("sym_visibility"),
mlirStringAttrGet(context, toMlirStringRef("private")));
};
// We make an effort for the func op's symbol name to be useful for debugging,
// but still clearly non-load-bearing.
std::string symName =
"__npcomp_priv_fn." + function->qualname().qualifiedName();
GraphImporter::MlirMappingOptions mappingOptions{context, symName, symName,
typeMapper, inserter};
GraphImporter importer(function->graph(), mappingOptions);
importer.initialize();
importer.importGenericFunc();
createMlirOperationAtEnd(
nnModuleBody, "torch.method", loc,
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
context, toMlirStringRef(symName))));
}
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirContext context) {
// When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule())
// ivalue.toModule().dump(true, true, true);
IValueImporter importer(block, context);
importer.importIValue(ivalue);
}


@ -0,0 +1,30 @@
//===- ivalue_importer.h ----------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch IValues.
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_IVALUE_IMPORTER_H


@ -31,6 +31,52 @@ inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
return mlirNamedAttributeGet(ident, attr);
}
inline void addToMlirOperationState(MlirOperationState &state,
MlirNamedAttribute namedAttr) {
mlirOperationStateAddAttributes(&state, 1, &namedAttr);
}
inline void addToMlirOperationState(MlirOperationState &state,
MlirRegion region) {
mlirOperationStateAddOwnedRegions(&state, 1, &region);
}
inline void addToMlirOperationState(MlirOperationState &state,
MlirValue value) {
mlirOperationStateAddOperands(&state, 1, &value);
}
inline void addToMlirOperationState(MlirOperationState &state,
MlirType resultType) {
mlirOperationStateAddResults(&state, 1, &resultType);
}
template <typename T, typename... Ts>
void addToMlirOperationState(MlirOperationState &state, T &&t, Ts &&...ts) {
addToMlirOperationState(state, t);
addToMlirOperationState(state, std::forward<Ts>(ts)...);
}
inline void addToMlirOperationState(MlirOperationState &state) {}
template <typename... Ts>
MlirOperation createMlirOperation(std::string name, MlirLocation loc,
Ts &&...ts) {
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
addToMlirOperationState(state, std::forward<Ts>(ts)...);
return mlirOperationCreate(&state);
}
template <typename... Ts>
MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
MlirLocation loc, Ts &&...ts) {
MlirOperation operation =
createMlirOperation(name, loc, std::forward<Ts>(ts)...);
mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
operation);
return operation;
}
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_MLIR_UTILS_H
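
A hedged usage sketch of the helpers above (overload resolution routes each
trailing argument to the matching mlirOperationStateAdd* call; `block`,
`loc`, and `context` are assumed to be in scope):

    // Build a basicpy.bool_constant with one result type and one attribute,
    // inserted before the block's terminator (as ivalue_importer.cpp does).
    MlirOperation op = createMlirOperationAtEnd(
        block, "basicpy.bool_constant", loc, npcompBoolTypeGet(context),
        toMlirNamedAttribute("value", mlirBoolAttrGet(context, /*value=*/1)));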


@ -8,6 +8,7 @@
#include "module_builder.h" #include "module_builder.h"
#include "graph_importer.h" #include "graph_importer.h"
#include "ivalue_importer.h"
#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
@ -109,6 +110,11 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
return function;
}
void ModuleBuilder::importModule(torch::jit::Module jitModule) {
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
mlirModuleGetContext(module));
}
FuncBuilder::Inserter ModuleBuilder::createInserter() {
MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator;
@ -129,5 +135,6 @@ void ModuleBuilder::bind(py::module &m) {
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("capture_function", &ModuleBuilder::startCaptureFunction,
py::keep_alive<0, 1>())
.def("import_function", &ModuleBuilder::importFunction);
.def("import_function", &ModuleBuilder::importFunction)
.def("import_module", &ModuleBuilder::importModule);
}


@ -16,6 +16,7 @@
#include <ATen/Tensor.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
@ -43,6 +44,9 @@ public:
torch::jit::StrongFunctionPtr
importFunction(torch::jit::StrongFunctionPtr function);
// Imports a torch::jit::Module into the current module.
void importModule(torch::jit::Module jitModule);
private:
FuncBuilder::Inserter createInserter();
MlirBlock getBodyBlock();


@ -0,0 +1,212 @@
//===- torch_to_mlir_utils.cpp --------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include "ivalue_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
using namespace torch_mlir;
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
throw std::invalid_argument(message.str());
}
return type;
}
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
mlirEmitError(loc, message.str().c_str());
}
return type;
}
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
using c10::ScalarType;
switch (scalarType) {
case ScalarType::Byte:
// TODO: convert to mlirIntegerTypeUnsignedGet once supported.
return mlirIntegerTypeGet(context, 8);
case ScalarType::Char:
return mlirIntegerTypeGet(context, 8);
case ScalarType::Short:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 16);
case ScalarType::Int:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 32);
case ScalarType::Long:
// TODO: convert to mlirIntegerTypeSignedGet once supported.
return mlirIntegerTypeGet(context, 64);
case ScalarType::Bool:
return npcompBoolTypeGet(context);
case ScalarType::Double:
return mlirF64TypeGet(context);
case ScalarType::Float:
return mlirF32TypeGet(context);
case ScalarType::BFloat16:
return mlirBF16TypeGet(context);
case ScalarType::Half:
return mlirF16TypeGet(context);
default: {
return {nullptr};
}
}
}
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
using c10::TypeKind;
auto kind = torchType->kind();
switch (kind) {
case TypeKind::TensorType: {
auto tensorType = torchType->cast<c10::TensorType>();
// Element type.
MlirType elementType;
if (tensorType->scalarType()) {
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
if (mlirTypeIsNull(elementType))
return {nullptr};
} else {
elementType = npcompAnyDtypeTypeGet(context);
}
// Sizes.
auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) {
// Unranked.
return npcompNdArrayTypeGetUnranked(elementType);
}
// Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes();
std::vector<int64_t> dims;
dims.resize(*sizes.rank());
for (size_t i = 0; i < dims.size(); ++i) {
auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
}
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
}
case TypeKind::ClassType: {
return npcompNnModuleTypeGet(context);
}
case TypeKind::FloatType: {
return mlirF64TypeGet(context);
}
case TypeKind::IntType: {
return mlirIntegerTypeGet(context, 64);
}
default: {
std::stringstream message;
message << "unable to map Torch type " << *torchType << " to MLIR type";
mlirEmitError(loc, message.str().c_str());
return {nullptr};
}
}
}
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
if (!tensor.defined()) {
// Undefined tensors are equivalent to None.
// This may need to be re-evaluated at some point.
return npcompNoneTypeGet(context);
}
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
// TODO: Decide when it is necessary to take strides into account. Right now,
// just erase them and let the compiler decide.
auto sizes = tensor.sizes();
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
}
MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
MlirLocation loc) {
MlirContext context = mlirLocationGetContext(loc);
TypeMapper typeMapper(context);
using at::ScalarType;
auto throwUnsupportedTensorError = [&]() {
std::stringstream msg;
msg << "Unsupported import tensor type: " << tensor;
throw std::invalid_argument(msg.str());
};
// Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
if (!tensor.is_contiguous())
tensor = tensor.contiguous();
// The flat number of bytes throws an exception for tensors that are not
// dense and accessible as such.
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
c10::Layout::Strided);
// Construct the ShapedType.
MlirType elementType;
if (tensor.scalar_type() == ScalarType::Bool) {
// Bool is a special case. When used as an element type, it must be i1.
// The generalized (non-Tensor) conversion assumes that Bool is the
// Basicpy bool type.
elementType = mlirIntegerTypeGet(context, 1);
} else {
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
}
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType shapedType = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType, loc);
if (mlirTypeIsNull(shapedType)) {
throwUnsupportedTensorError();
}
// Import DenseElementsAttr data.
// TODO: Support bool tensors.
// TODO: More import formats in C-API.
auto numElements = tensor.numel();
auto tensorData = tensor.data_ptr();
switch (tensor.scalar_type()) {
case ScalarType::Int:
return mlirDenseElementsAttrInt32Get(
shapedType, numElements, static_cast<const int32_t *>(tensorData));
break;
case ScalarType::Long:
return mlirDenseElementsAttrInt64Get(
shapedType, numElements, static_cast<const int64_t *>(tensorData));
break;
case ScalarType::Float:
return mlirDenseElementsAttrFloatGet(
shapedType, numElements, static_cast<const float *>(tensorData));
break;
case ScalarType::Double:
return mlirDenseElementsAttrDoubleGet(
shapedType, numElements, static_cast<const double *>(tensorData));
break;
case ScalarType::Bool:
// TODO: Add a test specifically for bool and ensure consistency between
// storage format and load format
// (https://github.com/llvm/mlir-npcomp/issues/100).
return mlirDenseElementsAttrBoolGet(shapedType, numElements,
static_cast<const int *>(tensorData));
break;
default:
throwUnsupportedTensorError();
}
return {nullptr}; // Unreachable.
}
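
A hedged usage sketch of the helper above, mirroring its call sites in
ivalue_importer.cpp and acap_dispatch.cpp (`tensor` and `loc` assumed in
scope):

    // Convert a strided tensor into a DenseElementsAttr holding its data.
    MlirAttribute valueAttribute = converTensorToMlirElementsAttr(tensor, loc);
    // The attribute carries its ShapedType, so callers recover the type from
    // the attribute instead of recomputing it.
    MlirType shapedType = mlirAttributeGetType(valueAttribute);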


@ -0,0 +1,60 @@
//===- torch_to_mlir_utils.h ------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
#include <memory>
#include "../pybind.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Maps various runtime types to MlirType.
class TypeMapper {
public:
TypeMapper(MlirContext context) : context(context) {}
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Throws std::invalid_argument on failure.
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
/// Gets a corresponding MlirType for the forward component of a tensor.
/// Throws std::invalid_argument on failure.
MlirType forwardTensorToType(at::Tensor tensor);
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Torch ScalarType is used to represent the possible element types of Torch
/// tensors, which is different from the set of types used to represent
/// Python numeric scalar values (which are always either f64 or i64).
/// Returns a null type on failure and emits a diagnostic.
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
private:
/// Maps from a scalar type and returns null if no match (no other error
/// reporting).
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
MlirContext context;
};
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
MlirLocation loc);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
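
A usage sketch for TypeMapper under the contract documented above
(illustrative; `context`, `loc`, and `scalarType` assumed in scope):

    TypeMapper typeMapper(context);
    // Throwing variant: raises std::invalid_argument on unsupported types.
    MlirType f32 = typeMapper.mapFromTorchScalarType(c10::ScalarType::Float);
    // Diagnostic variant: emits an error at `loc` and returns null instead.
    MlirType mapped = typeMapper.mapFromTorchScalarType(loc, scalarType);
    if (mlirTypeIsNull(mapped)) {
      // ...handle or propagate the failure...
    }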


@ -2,26 +2,21 @@
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s
@torch.jit.script
class ExampleClass:
    def __init__(self, x):
        self.x = x
mb = torch_mlir.ModuleBuilder()
# For now, TorchScript classes are wholly unsupported, so use it to test
# type conversion errors.
# To test errors, use a type that we don't support yet.
try:
    @mb.import_function
    @torch.jit.script
    def import_class(c: ExampleClass):
        return c.x
    def import_class(x: typing.Any):
        return x
except RuntimeError as e:
    # TODO: Once diagnostics are enabled, verify the actual error emitted.
    assert str(e) == "could not convert function input type"


@ -0,0 +1,43 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class Submodule(torch.nn.Module):
def forward(self, x):
return x
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s = Submodule()
def forward(self, x, y):
return x * y
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
# However, we do make an attempt to ensure that the names are debuggable.
#
# The names have the following structure:
# - `__npcomp_priv_fn`: marker that this symbol name is private to npcomp
# - `__torch__.Submodule.forward`: the name that TorchScript gives the function
# - For those curious, the `__torch__` would be the Python module name, but in
# the case that the name is `__main__` Torch replaces it with `__torch__` to
# avoid collisions.
# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward
# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()


@ -0,0 +1,26 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# CHECK-LABEL: torch.nn_module
# CHECK: loc("{{.*}}methods-locations.py":[[@LINE+1]]
return x * y
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print(enable_debug_info=True)


@ -0,0 +1,39 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x * y
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
# CHECK-LABEL: func private
# CHECK-SAME: @[[SYMNAME:.*]](
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.method "forward", @[[SYMNAME]]
# CHECK: }
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()


@ -0,0 +1,34 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.i = 3
self.f = 42.5
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
# CHECK: %[[MODULE:.*]] = torch.nn_module {
# Note: for some reason, Torch always adds a "training" property to all modules.
# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType
# CHECK: torch.attr "i", %[[N3]] : i64
# CHECK: torch.attr "f", %[[N42]] : f64
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()


@ -0,0 +1,52 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class Submodule(torch.nn.Module):
def __init__(self, n):
super().__init__()
self.n = n
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s0 = Submodule(0)
self.s1 = Submodule(1)
# CHECK: %[[T:.*]] = basicpy.bool_constant true
# CHECK: %[[T1:.*]] = basicpy.bool_constant true
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
# CHECK: %[[S0:.*]] = torch.nn_module {
# CHECK: torch.attr "training", %[[T1]] : !basicpy.BoolType
# CHECK: torch.attr "n", %[[N0]] : i64
# CHECK: }
# CHECK: %[[T2:.*]] = basicpy.bool_constant true
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[S1:.*]] = torch.nn_module {
# CHECK: torch.attr "training", %[[T2]] : !basicpy.BoolType
# CHECK: torch.attr "n", %[[N1]] : i64
# CHECK: }
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module
# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module
# CHECK: }
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()


@ -0,0 +1,35 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# TODO: Test (and make work) tensors that alias each other.
self.t = torch.ones(1)
self.p = torch.nn.Parameter(torch.arange(3.0))
# CHECK: %[[CP:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>
# CHECK: %[[P:.*]] = numpy.create_array_from_tensor %[[CP]] : (tensor<3xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: }
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()


@ -107,6 +107,16 @@ int npcompTypeIsATuple(MlirType t);
/** Gets the generic Python "tuple" type. */
MlirType npcompTupleTypeGet(MlirContext context);
/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/
/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t);
/** Gets the singleton torch.nn.Module type. */
MlirType npcompNnModuleTypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif
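
A hedged round-trip sketch of the new C API entry points (assumes a
context with the Torch dialect registered):

    MlirType nnModuleType = npcompNnModuleTypeGet(context);
    // The predicate recognizes exactly the type the getter produces.
    assert(npcompTypeIsANnModule(nnModuleType));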


@ -12,7 +12,7 @@
include "npcomp/Dialect/ATen/IR/ATenDialect.td" include "npcomp/Dialect/ATen/IR/ATenDialect.td"
include "npcomp/Dialect/ATen/IR/ATenOpInterface.td" include "npcomp/Dialect/ATen/IR/ATenOpInterface.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td" include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "npcomp/Dialect/Torch/IR/TorchBase.td" include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"


@ -5,6 +5,11 @@ mlir_tablegen(TorchDialect.h.inc -gen-dialect-decls -dialect=torch)
add_public_tablegen_target(MLIRTorchOpsIncGen)
add_dependencies(mlir-headers MLIRTorchOpsIncGen)
set(LLVM_TARGET_DEFINITIONS TorchTypes.td)
mlir_tablegen(TorchTypes.h.inc -gen-typedef-decls)
mlir_tablegen(TorchTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTorchTypesIncGen)
set(LLVM_TARGET_DEFINITIONS OpInterfaces.td)
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)


@ -26,7 +26,7 @@ def Torch_Dialect : Dialect {
- Transitions between mutable and immutable tensors.
- Gradient associations and management.
- Custom ops.
- Types specific to PyTorch.
- Types specific to PyTorch such as torch.nn.Module structures.
- Module level constructs like quantization parameters, globals, etc.
Where possible, this dialect composes with types and ops from the `Numpy`
@ -40,92 +40,4 @@ def Torch_Dialect : Dialect {
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
// Torch has a fairly advanced and featureful Tensor type, and some of the
// semantics are important to preserve in a compilation context. In the future,
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
// and interop are well served by existing tensor-like types, which are
// specifically permitted. Typically, on import, constraints are fairly loose
// and based on how the program is captured. Settling on and refining to
// specific types is done as part of lowering.
//
// While lowering it is useful to be able to distinguish between mutable and
// immutable tensors:
// - Mutable tensors can alias.
// - Mutable tensors can be a view over another mutable tensor.
// - Mutable tensors act as if reference counted and exist for the lifetime
// of any reference or derived view.
// Conversely, immutable tensors:
// - Are normal SSA values representing the contents of the tensor.
// - Cannot alias.
// - Cannot be a view of any mutable value.
// - Have undefined lifetimes.
//
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
// however, when lowering to specific ops, further constraints are introduced,
// necessitating copies, loads, and stores to be inserted to bridge worlds.
def AnyTorchImmutableTensor : AnyTypeOf<[
// Normal MLIR immutable tensors.
AnyTensor,
], "allowable torch immutable tensor">;
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
AnyTorchImmutableTensor,
Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;
def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the
// same aliasing as Torch tensors.
Numpy_NdArrayType,
], "allowable torch mutable tensor">;
def AnyTorchTensorType : AnyTypeOf<[
AnyTorchImmutableTensor,
AnyTorchMutableTensor,
], "Any tensor type legal to pass to a Torch kernel">;
def AnyTorchScalarType : AnyTypeOf<[
AnySignedInteger,
AnyFloat,
Basicpy_BoolType,
Basicpy_StrType,
Basicpy_NoneType,
// Allow signless integers for ease of conversions. In general, this
// dialect uses signed integers.
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;
def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,
], "Any permissible bool type">;
def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;
def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;
def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,
Basicpy_NoneType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_BASE


@ -12,7 +12,9 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc" #include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"


@ -9,8 +9,9 @@
#ifndef TORCH_OPS
#define TORCH_OPS
include "npcomp/Dialect/Torch/IR/TorchBase.td"
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
@ -47,4 +48,87 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [
}];
}
//===----------------------------------------------------------------------===//
// TorchScript modeling ops.
//===----------------------------------------------------------------------===//
def Torch_NnModuleOp : Torch_Op<"nn_module", [
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
let summary = "Constructs a torch.nn.Module";
let description = [{
This op is used to represent a torch.nn.Module when importing a
graph of Python objects.
This op returns a new torch.nn.Module as an SSA value, with a set of
declaratively specified properties.
Example:
```mlir
%2 = torch.nn_module {
torch.attr "b", %bool_true : !basicpy.BoolType
torch.attr "i", %num3_i64 : i64
torch.attr "f", %num : f64
torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
torch.attr "submodule", %1 : !torch.nn.Module
torch.method "method", @f
}
```
}];
let arguments = (ins);
let results = (outs Torch_NnModuleType:$result);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let assemblyFormat = "$region attr-dict";
}
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
let summary = "Implicit terminator for torch.nn_module";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}
def Torch_AttrOp : Torch_Op<"attr", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
let summary = "Define an attribute of a torch.nn.Module";
let description = [{
This op declaratively specifies that the parent torch.nn_module has an
attribute `name` with value `value`, which is allowed to be an arbitrary
Torch-compatible SSA value, including other torch.nn.Module's.
}];
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
let results = (outs);
let assemblyFormat = [{
$name `,` $value attr-dict `:` type($value)
}];
}
def Torch_MethodOp : Torch_Op<"method", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Define a method of a torch.nn.Module";
let description = [{
This op declaratively specifies that the parent torch.nn_module has a
method `name` which calls `function`. `function` is an unbound function.
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
"self" object).
}];
let arguments = (ins StrAttr:$name, FlatSymbolRefAttr:$function);
let results = (outs);
let assemblyFormat = [{
$name `,` $function attr-dict
}];
}
#endif // TORCH_OPS


@ -0,0 +1,17 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
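
On the C++ side the tablegen'd type is a normal singleton MLIR type; a
small sketch (this is the same API the C API wrapper in this commit calls
through, with `context` an MLIRContext*):

    // NnModuleType is generated from TorchTypes.td by the -gen-typedef-*
    // backends wired up in the CMake changes above.
    mlir::Type t = mlir::NPCOMP::Torch::NnModuleType::get(context);
    assert(t.isa<mlir::NPCOMP::Torch::NnModuleType>());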


@ -0,0 +1,118 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_TYPES
#define TORCH_TYPES
include "npcomp/Dialect/Torch/IR/TorchBase.td"
//===----------------------------------------------------------------------===//
// Type defs
//===----------------------------------------------------------------------===//
class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name> {
let mnemonic = typeMnemonic;
}
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
let summary = "torch.nn.Module";
let description = [{
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
// Torch has a fairly advanced and featureful Tensor type, and some of the
// semantics are important to preserve in a compilation context. In the future,
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
// and interop are well served by existing tensor-like types, which are
// specifically permitted. Typically, on import, constraints are fairly loose
// and based on how the program is captured. Settling on and refining to
// specific types is done as part of lowering.
//
// While lowering it is useful to be able to distinguish between mutable and
// immutable tensors:
// - Mutable tensors can alias.
// - Mutable tensors can be a view over another mutable tensor.
// - Mutable tensors act as if reference counted and exist for the lifetime
// of any reference or derived view.
// Conversely, immutable tensors:
// - Are normal SSA values representing the contents of the tensor.
// - Cannot alias.
// - Cannot be a view of any mutable value.
// - Have undefined lifetimes.
//
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
// however, when lowering to specific ops, further constraints are introduced,
// necessitating copies, loads, and stores to be inserted to bridge worlds.
def AnyTorchImmutableTensor : AnyTypeOf<[
// Normal MLIR immutable tensors.
AnyTensor,
], "allowable torch immutable tensor">;
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
AnyTorchImmutableTensor,
Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;
def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the
// same aliasing as Torch tensors.
Numpy_NdArrayType,
], "allowable torch mutable tensor">;
def AnyTorchTensorType : AnyTypeOf<[
AnyTorchImmutableTensor,
AnyTorchMutableTensor,
], "Any tensor type legal to pass to a Torch kernel">;
def AnyTorchScalarType : AnyTypeOf<[
AnySignedInteger,
AnyFloat,
Basicpy_BoolType,
Basicpy_StrType,
Basicpy_NoneType,
// Allow signless integers for ease of conversions. In general, this
// dialect uses signed integers.
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;
def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,
], "Any permissible bool type">;
def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;
def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;
def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,
Basicpy_NoneType,
Torch_NnModuleType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_TYPES


@ -16,4 +16,5 @@ add_npcomp_library(NPCOMPCAPI
NPCOMPInitAll
NPCOMPBasicpyDialect
NPCOMPNumpyDialect
NPCOMPTorchDialect
)


@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
@ -133,3 +134,17 @@ int npcompTypeIsATuple(MlirType t) {
MlirType npcompTupleTypeGet(MlirContext context) {
return wrap(Basicpy::TupleType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/
/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
}
/** Gets the singleton torch.nn.Module type. */
MlirType npcompNnModuleTypeGet(MlirContext context) {
return wrap(Torch::NnModuleType::get(unwrap(context)));
}


@ -9,6 +9,7 @@ add_npcomp_dialect_library(NPCOMPATenDialect
MLIRATenIncGen
#MLIRATenEnumsIncGen
MLIRATenOpInterfacesIncGen
MLIRTorchOpInterfacesIncGen
#MLIRATenToStdIncGen
LINK_LIBS PUBLIC


@ -31,7 +31,10 @@ OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
void BoolConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "bool");
if (value())
setNameFn(getResult(), "bool_true");
else
setNameFn(getResult(), "bool_false");
}
//===----------------------------------------------------------------------===//


@ -8,6 +8,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
DEPENDS
MLIRTorchOpsIncGen
MLIRTorchOpInterfacesIncGen
MLIRTorchTypesIncGen
LINK_COMPONENTS
Core


@ -7,16 +7,47 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP::Torch; using namespace mlir::NPCOMP::Torch;
//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
void TorchDialect::initialize() { void TorchDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc" #include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
>(); >();
addTypes<
#define GET_TYPEDEF_LIST
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
>();
}
Type TorchDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (Type type = generatedTypeParser(getContext(), parser, keyword))
return type;
parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
<< keyword << "'";
return Type();
}
void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
if (failed(generatedTypePrinter(type, printer)))
llvm_unreachable("unknown 'torch' type");
} }
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
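
Note: `generatedTypeParser` and `generatedTypePrinter` are emitted by tablegen
into TorchTypes.cpp.inc via GET_TYPEDEF_CLASSES. As an illustration only (not
the real generated code), the dispatcher that parseType delegates to behaves
roughly like the sketch below, with the `nn.Module` mnemonic inferred from the
`!torch.nn.Module` syntax used in the tests further down:

    // Illustrative shape of the generated keyword dispatcher.
    static Type exampleGeneratedTypeParser(MLIRContext *context,
                                           DialectAsmParser &parser,
                                           StringRef keyword) {
      if (keyword == "nn.Module")
        return NnModuleType::get(context);
      return Type(); // unknown keyword; parseType emits the diagnostic
    }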

@@ -9,6 +9,7 @@
 #include "npcomp/Dialect/Torch/IR/TorchOps.h"

 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
@@ -16,9 +17,7 @@

 using namespace mlir;
 using namespace mlir::NPCOMP;
+using namespace mlir::NPCOMP::Torch;

-#define GET_OP_CLASSES
-#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
-
 static SmallVector<StringRef, 4> strArrayAttrToVector(ArrayAttr array) {
   SmallVector<StringRef, 4> strings;
@@ -29,12 +28,12 @@ static SmallVector<StringRef, 4> strArrayAttrToVector(ArrayAttr array) {
   return strings;
 }

-// -----------------------------------------------------------------------------
-// KernelCall op
-// -----------------------------------------------------------------------------
+//===----------------------------------------------------------------------===//
+// KernelCallOp
+//===----------------------------------------------------------------------===//

-Torch::KernelMetadata Torch::KernelCallOp::getTorchKernelMetadata() {
-  return Torch::KernelMetadata{
+KernelMetadata KernelCallOp::getTorchKernelMetadata() {
+  return KernelMetadata{
       .kernelName = kernelName(),
       .isVararg = sigIsVararg(),
       .isVarret = sigIsVarret(),
@@ -42,3 +41,29 @@ Torch::KernelMetadata Torch::KernelCallOp::getTorchKernelMetadata() {
       .returnTypes = strArrayAttrToVector(sigRetTypes()),
   };
 }
+
+//===----------------------------------------------------------------------===//
+// MethodOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
+  if (!func)
+    return emitError() << "'" << function()
+                       << "' does not reference a valid function";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NnModuleOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(NnModuleOp op) {
+  for (Operation &child : *op.getBody())
+    if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child))
+      return child.emitOpError() << "is not allowed inside `torch.nn_module`";
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
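
Note: because the NnModuleOp verifier above restricts the body to `torch.attr`,
`torch.method`, and the terminator, a consumer can walk a module's slots
without defensive handling of other ops. A minimal sketch under that
assumption; the `name()` accessor on AttrOp is assumed from the op's
"name, value" syntax, while `function()` on MethodOp appears in the verifier:

    // Enumerate the attribute and method slots of every torch.nn_module.
    static void dumpNnModules(ModuleOp module) {
      module.walk([](NnModuleOp nnModule) {
        for (Operation &child : *nnModule.getBody()) {
          if (auto attr = dyn_cast<AttrOp>(&child))
            llvm::outs() << "attr: " << attr.name() << "\n";
          else if (auto method = dyn_cast<MethodOp>(&child))
            llvm::outs() << "method: " << method.function() << "\n";
        }
      });
    }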

@@ -50,7 +50,7 @@ func @numeric_constant_complex_f32() -> complex<f32> {
 // -----
 // CHECK-LABEL: @bool_constant
 func @bool_constant() -> !basicpy.BoolType {
-  // CHECK: %bool = basicpy.bool_constant true
+  // CHECK: %bool_true = basicpy.bool_constant true
   %0 = basicpy.bool_constant true
   return %0 : !basicpy.BoolType
 }

@@ -35,5 +35,9 @@ func @numeric_constant() {
   %2 = basicpy.numeric_constant 2.0 : f32
   // CHECK: %num_0 = basicpy.numeric_constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
   %3 = basicpy.numeric_constant [2.0 : f32, 3.0 : f32] : complex<f32>
+  // CHECK: %bool_true = basicpy.bool_constant true
+  %4 = basicpy.bool_constant true
+  // CHECK: %bool_false = basicpy.bool_constant false
+  %5 = basicpy.bool_constant false
   return
 }

@@ -0,0 +1,15 @@
+// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
+
+// -----
+
+torch.nn_module {
+  // expected-error @+1 {{'func' op is not allowed inside `torch.nn_module`}}
+  func @f()
+}
+
+// -----
+
+torch.nn_module {
+  // expected-error @+1 {{'invalidSym' does not reference a valid function}}
+  torch.method "f", @invalidSym
+}

@@ -1,9 +1,29 @@
 // RUN: npcomp-opt %s | npcomp-opt | FileCheck %s

 func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
-  // CHECK: %0 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
+  // CHECK: torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
   %1 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> (tensor<*xf32>) {
     sigArgTypes = [], sigRetTypes = [], sigIsVararg = false, sigIsVarret = false, sigIsMutable = false
   }
   return %1 : tensor<*xf32>
 }
+
+%bool_true = basicpy.bool_constant true
+%num3_i64 = basicpy.numeric_constant 3 : i64
+%num = basicpy.numeric_constant 4.250000e+01 : f64
+%cst = constant dense<1.000000e+00> : tensor<1xf32>
+%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+
+func @f(%arg0: !torch.nn.Module) {
+  return
+}
+
+%submodule = torch.nn_module {}
+torch.nn_module {
+  torch.attr "b", %bool_true : !basicpy.BoolType
+  torch.attr "i", %num3_i64 : i64
+  torch.attr "f", %num : f64
+  torch.attr "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
+  torch.attr "submodule", %submodule : !torch.nn.Module
+  torch.method "method", @f
+}