Remove TypeMapper

pull/309/head
Sean Silva 2021-09-17 04:12:19 +00:00
parent 6f710bbc47
commit 3e3459690c
4 changed files with 56 additions and 108 deletions

View File

@ -100,8 +100,7 @@ class IValueImporter {
public:
IValueImporter(MlirBlock importBlock, MlirContext context,
ClassAnnotator &annotator)
: importBlock(importBlock), context(context), typeMapper(context),
annotator(annotator) {}
: importBlock(importBlock), context(context), annotator(annotator) {}
MlirValue importIValue(c10::IValue ivalue);
@ -116,7 +115,6 @@ private:
MlirBlock importBlock;
MlirContext context;
TypeMapper typeMapper;
ClassAnnotator &annotator;
// Map tracking already-imported values.
@ -275,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc,
torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, list.elementType())),
getMlirTypeFromTorchType(loc, list.elementType())),
elems);
return mlirOperationGetResult(operation, 0);
}
@ -290,8 +288,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc,
torchMlirTorchDictTypeGet(
typeMapper.mapFromTorchType(loc, dict.keyType()),
typeMapper.mapFromTorchType(loc, dict.valueType())),
getMlirTypeFromTorchType(loc, dict.keyType()),
getMlirTypeFromTorchType(loc, dict.valueType())),
keys, values);
return mlirOperationGetResult(operation, 0);
}
@ -385,7 +383,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
if (tensor.qscheme() == c10::kPerTensorAffine) {
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
@ -461,9 +459,8 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
toMlirNamedAttribute(
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type",
mlirTypeAttrGet(typeMapper.mapFromTorchType(
loc, classAttribute.getType()))),
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
loc, classAttribute.getType()))),
isPrivate);
}
@ -523,7 +520,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
}
std::vector<int64_t> shape = *maybeShape;
MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
MlirType dtype = getMlirTypeForTorchScalarType(
mlirLocationUnknownGet(context), *maybeDtype);
MlirType typeBound;
// `std::vector`'s `.data()` method can return nullptr when the

View File

@ -69,7 +69,6 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
}
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
TypeMapper typeMapper(context);
MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind();
@ -142,13 +141,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
} else if (output->type()->cast<c10::IntType>()) {
op = createMlirOperation(
"torch.constant.int", loc,
typeMapper.mapFromTorchType(loc, output->type()),
getMlirTypeFromTorchType(loc, output->type()),
toMlirNamedAttribute("value",
importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::FloatType>()) {
op = createMlirOperation(
"torch.constant.float", loc,
typeMapper.mapFromTorchType(loc, output->type()),
getMlirTypeFromTorchType(loc, output->type()),
toMlirNamedAttribute("value",
importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) {
@ -233,7 +232,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
torch::jit::Function *function = classType->findMethod(methodName);
torch::jit::Block *calleeEntryBlock = function->graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return typeMapper.mapFromTorchType(loc, v->type());
return getMlirTypeFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.CallMethod", loc,
@ -251,7 +250,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
torch::jit::Block *calleeEntryBlock =
functionType->function()->graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return typeMapper.mapFromTorchType(loc, v->type());
return getMlirTypeFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "std.call_indirect", loc,

View File

@ -1,4 +1,4 @@
//===- ivalue_importer.cpp ------------------------------------------------===//
//===- torch_to_mlir_utils.cpp --------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
@ -20,28 +20,8 @@
using namespace torch_mlir;
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
throw std::invalid_argument(message.str());
}
return type;
}
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType) {
auto type = rawMapFromTorchScalarType(scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
mlirEmitError(loc, message.str().c_str());
}
return type;
}
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
c10::ScalarType scalarType) {
using c10::ScalarType;
switch (scalarType) {
case ScalarType::Byte:
@ -74,6 +54,18 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
}
}
MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType) {
auto type =
getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
if (mlirTypeIsNull(type)) {
std::stringstream message;
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
mlirEmitError(loc, message.str().c_str());
}
return type;
}
// Types (such as `LinearPackedParamsBase`) implemented with the
// `torch::CustomClassHolder` mechanism described at
// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
@ -119,8 +111,9 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
throw mlir_diagnostic_emitted();
}
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
MlirContext context = mlirLocationGetContext(loc);
using c10::TypeKind;
auto kind = torchType->kind();
switch (kind) {
@ -129,7 +122,8 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
// Element type.
MlirType elementType = {nullptr};
if (tensorType->scalarType()) {
elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
elementType =
getMlirTypeForTorchScalarType(loc, *tensorType->scalarType());
if (mlirTypeIsNull(elementType))
return {nullptr};
}
@ -172,27 +166,27 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return torchMlirTorchStringTypeGet(context);
}
case TypeKind::OptionalType: {
return torchMlirTorchOptionalTypeGet(mapFromTorchType(
return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
}
case TypeKind::TupleType: {
std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type :
torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back(mapFromTorchType(loc, type));
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
}
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::ListType: {
return torchMlirTorchListTypeGet(mapFromTorchType(
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::DictType: {
auto dictType = torchType->cast<c10::DictType>();
return torchMlirTorchDictTypeGet(
mapFromTorchType(loc, dictType->getKeyType()),
mapFromTorchType(loc, dictType->getValueType()));
getMlirTypeFromTorchType(loc, dictType->getKeyType()),
getMlirTypeFromTorchType(loc, dictType->getValueType()));
}
case TypeKind::NoneType: {
return torchMlirTorchNoneTypeGet(context);
@ -223,29 +217,12 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
}
}
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
if (!tensor.defined()) {
// Undefined tensors are equivalent to None.
// This may need to be re-evaluated at some point.
return torchMlirTorchNoneTypeGet(context);
}
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
// TODO: Decide when it is necessary to take strides into account. Right now,
// just erase them and let the compiler decide.
auto sizes = tensor.sizes();
return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
sizes.data(), elementType);
}
MlirType
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema) {
MlirLocation loc = mlirLocationUnknownGet(context);
TypeMapper typeMapper(context);
auto mapType = [&](const c10::TypePtr &torchType) {
MlirType type = typeMapper.mapFromTorchType(loc, torchType);
MlirType type = getMlirTypeFromTorchType(loc, torchType);
if (mlirTypeIsNull(type)) {
std::stringstream msg;
msg << "unsupported type in function schema: '"
@ -267,8 +244,6 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
MlirLocation loc) {
MlirContext context = mlirLocationGetContext(loc);
TypeMapper typeMapper(context);
using at::ScalarType;
auto throwUnsupportedTensorError = [&]() {
@ -292,8 +267,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
// quantized types it might differ (e.g. QInt8 becomes Char). Caller code is
// responsible for materializing the proper op that incorporates the
// quantization scheme to create a tensor of e.g. `!torch.qint8` element type.
MlirType elementType = typeMapper.mapFromTorchScalarType(
c10::toUnderlying(tensor.scalar_type()));
MlirType elementType = getMlirTypeForTorchScalarType(
loc, c10::toUnderlying(tensor.scalar_type()));
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType shapedType = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(), elementType, {nullptr});
@ -378,10 +353,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
std::vector<MlirType>
torch_mlir::getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values) {
TypeMapper typeMapper(mlirLocationGetContext(loc));
std::vector<MlirType> ret;
for (auto value : values) {
MlirType t = typeMapper.mapFromTorchType(loc, value->type());
MlirType t = getMlirTypeFromTorchType(loc, value->type());
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("unsupported type");
ret.push_back(t);

View File

@ -20,44 +20,22 @@
namespace torch_mlir {
/// Maps various runtime types to MlirType.
class TypeMapper {
public:
TypeMapper(MlirContext context) : context(context) {}
/// Gets a corresponding MlirType for the Torch ScalarType.
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
/// type universe from the Python-based types modeled with `c10::Type`.
/// Compared to the Python types, which just have `int` and `float` and
/// `bool` numeric types, ScalarType has detailed bit-width and precision
/// considerations (which matter a lot for tensors, but don't really matter
/// for Python code).
///
/// Returns a null type on failure and emits a diagnostic.
MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
c10::ScalarType scalarType);
/// Gets a corresponding MlirType for the Torch ScalarType.
/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
/// type universe from the Python-based types modeled with `c10::Type`.
/// Compared to the Python types, which just have `int` and `float` and
/// `bool` numeric types, ScalarType has detailed bit-width and precision
/// considerations (which matter a lot for tensors, but don't really matter
/// for Python code).
///
/// Throws std::invalid_argument on failure.
MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
/// Gets a corresponding MlirType for the forward component of a tensor.
/// Throws std::invalid_argument on failure.
MlirType forwardTensorToType(at::Tensor tensor);
/// Gets a corresponding MlirType for the Torch ScalarType.
/// Torch ScalarType is used to represent the possible element types of Torch
/// tensors, which is different from the set of types used to represent
/// Python numeric scalar values (which are always either f64 or !torch.int).
///
/// Returns a null type on failure and emits a diagnostic.
MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
private:
/// Maps from a scalar type and returns null if no match (no other error
/// reporting).
MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
MlirContext context;
};
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType getMlirTypeFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType);
/// Creates a FunctionType suitable for expressing the signature of `schema`.
///