mirror of https://github.com/llvm/torch-mlir
Remove TypeMapper
parent 6f710bbc47
commit 3e3459690c
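The change replaces the stateful TypeMapper helper, which carried an MlirContext member, with the free functions torch_mlir::getMlirTypeForTorchScalarType and torch_mlir::getMlirTypeFromTorchType. Both recover the MlirContext from the MlirLocation they receive and report failure by emitting a diagnostic at that location and returning a null type. A minimal sketch of a call site after this change follows; the helper name makeNonValueTensorType is hypothetical, and the includes are assumptions (they presume this repository's torch_to_mlir_utils.h and torch-mlir-c/TorchTypes.h pull in the c10 and MLIR C API declarations used here).

#include "torch_to_mlir_utils.h"
#include "torch-mlir-c/TorchTypes.h"

#include <vector>

// Illustrative helper only; not part of this commit.
static MlirType makeNonValueTensorType(MlirLocation loc, c10::ScalarType dtype,
                                       const std::vector<int64_t> &shape) {
  // Replaces TypeMapper(context).mapFromTorchScalarType(loc, dtype): the
  // context now comes from the location rather than from a helper object.
  MlirType elementType = torch_mlir::getMlirTypeForTorchScalarType(loc, dtype);
  if (mlirTypeIsNull(elementType)) {
    // A diagnostic has already been emitted at `loc`; a null type is the
    // failure signal, matching the convention used throughout this diff.
    return {nullptr};
  }
  MlirContext context = mlirLocationGetContext(loc);
  return torchMlirTorchNonValueTensorTypeGet(context, shape.size(),
                                             shape.data(), elementType);
}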
@@ -100,8 +100,7 @@ class IValueImporter {
 public:
   IValueImporter(MlirBlock importBlock, MlirContext context,
                  ClassAnnotator &annotator)
-      : importBlock(importBlock), context(context), typeMapper(context),
-        annotator(annotator) {}
+      : importBlock(importBlock), context(context), annotator(annotator) {}
 
   MlirValue importIValue(c10::IValue ivalue);
 
@@ -116,7 +115,6 @@ private:
 
   MlirBlock importBlock;
   MlirContext context;
-  TypeMapper typeMapper;
   ClassAnnotator &annotator;
 
   // Map tracking already-imported values.
@@ -275,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.prim.ListConstruct", loc,
         torchMlirTorchListTypeGet(
-            typeMapper.mapFromTorchType(loc, list.elementType())),
+            getMlirTypeFromTorchType(loc, list.elementType())),
         elems);
     return mlirOperationGetResult(operation, 0);
   }
@@ -290,8 +288,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.prim.DictConstruct", loc,
         torchMlirTorchDictTypeGet(
-            typeMapper.mapFromTorchType(loc, dict.keyType()),
-            typeMapper.mapFromTorchType(loc, dict.valueType())),
+            getMlirTypeFromTorchType(loc, dict.keyType()),
+            getMlirTypeFromTorchType(loc, dict.valueType())),
         keys, values);
     return mlirOperationGetResult(operation, 0);
   }
@@ -385,7 +383,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
   MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
       context, shape.size(), shape.data(),
-      typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
+      getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
   if (tensor.qscheme() == c10::kPerTensorAffine) {
     MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
     MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
@@ -461,9 +459,8 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
         toMlirNamedAttribute(
             "name", mlirStringAttrGet(
                         context, toMlirStringRef(classAttribute.getName()))),
-        toMlirNamedAttribute("type",
-                             mlirTypeAttrGet(typeMapper.mapFromTorchType(
-                                 loc, classAttribute.getType()))),
+        toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
+                                         loc, classAttribute.getType()))),
         isPrivate);
   }
 
@@ -523,7 +520,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
       }
 
       std::vector<int64_t> shape = *maybeShape;
-      MlirType dtype = TypeMapper(context).mapFromTorchScalarType(
+      MlirType dtype = getMlirTypeForTorchScalarType(
           mlirLocationUnknownGet(context), *maybeDtype);
       MlirType typeBound;
       // `std::vector`'s `.data()` method can return nullptr when the
@@ -69,7 +69,6 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
 }
 
 void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
-  TypeMapper typeMapper(context);
   MlirLocation loc = getMlirLocationFromNode(context, node);
   auto kind = node->kind();
 
@@ -142,13 +141,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     } else if (output->type()->cast<c10::IntType>()) {
       op = createMlirOperation(
           "torch.constant.int", loc,
-          typeMapper.mapFromTorchType(loc, output->type()),
+          getMlirTypeFromTorchType(loc, output->type()),
           toMlirNamedAttribute("value",
                                importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::FloatType>()) {
       op = createMlirOperation(
           "torch.constant.float", loc,
-          typeMapper.mapFromTorchType(loc, output->type()),
+          getMlirTypeFromTorchType(loc, output->type()),
           toMlirNamedAttribute("value",
                                importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::StringType>()) {
@@ -233,7 +232,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     torch::jit::Function *function = classType->findMethod(methodName);
     torch::jit::Block *calleeEntryBlock = function->graph()->block();
     auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
-      return typeMapper.mapFromTorchType(loc, v->type());
+      return getMlirTypeFromTorchType(loc, v->type());
     });
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.CallMethod", loc,
@@ -251,7 +250,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     torch::jit::Block *calleeEntryBlock =
         functionType->function()->graph()->block();
     auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
-      return typeMapper.mapFromTorchType(loc, v->type());
+      return getMlirTypeFromTorchType(loc, v->type());
     });
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "std.call_indirect", loc,
@@ -1,4 +1,4 @@
-//===- ivalue_importer.cpp ------------------------------------------------===//
+//===- torch_to_mlir_utils.cpp --------------------------------------------===//
 //
 // This file is licensed under a pytorch-style license
 // See LICENSE for license information.
@@ -20,28 +20,8 @@
 
 using namespace torch_mlir;
 
-MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
-  auto type = rawMapFromTorchScalarType(scalarType);
-  if (mlirTypeIsNull(type)) {
-    std::stringstream message;
-    message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
-    throw std::invalid_argument(message.str());
-  }
-  return type;
-}
-
-MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
-                                            c10::ScalarType scalarType) {
-  auto type = rawMapFromTorchScalarType(scalarType);
-  if (mlirTypeIsNull(type)) {
-    std::stringstream message;
-    message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
-    mlirEmitError(loc, message.str().c_str());
-  }
-  return type;
-}
-
-MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
+static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
+                                                 c10::ScalarType scalarType) {
   using c10::ScalarType;
   switch (scalarType) {
   case ScalarType::Byte:
@@ -74,6 +54,18 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
   }
 }
 
+MlirType torch_mlir::getMlirTypeForTorchScalarType(MlirLocation loc,
+                                                   c10::ScalarType scalarType) {
+  auto type =
+      getMlirTypeForTorchScalarTypeRaw(mlirLocationGetContext(loc), scalarType);
+  if (mlirTypeIsNull(type)) {
+    std::stringstream message;
+    message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
+    mlirEmitError(loc, message.str().c_str());
+  }
+  return type;
+}
+
 // Types (such as `LinearPackedParamsBase`) implemented with the
 // `torch::CustomClassHolder` mechanism described at
 // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
@@ -119,8 +111,9 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
   throw mlir_diagnostic_emitted();
 }
 
-MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
-                                      const c10::TypePtr &torchType) {
+MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
+                                              const c10::TypePtr &torchType) {
+  MlirContext context = mlirLocationGetContext(loc);
   using c10::TypeKind;
   auto kind = torchType->kind();
   switch (kind) {
@@ -129,7 +122,8 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
     // Element type.
     MlirType elementType = {nullptr};
     if (tensorType->scalarType()) {
-      elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
+      elementType =
+          getMlirTypeForTorchScalarType(loc, *tensorType->scalarType());
       if (mlirTypeIsNull(elementType))
         return {nullptr};
     }
@@ -172,27 +166,27 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
     return torchMlirTorchStringTypeGet(context);
   }
   case TypeKind::OptionalType: {
-    return torchMlirTorchOptionalTypeGet(mapFromTorchType(
+    return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType(
         loc, torchType->cast<c10::OptionalType>()->getElementType()));
   }
   case TypeKind::TupleType: {
     std::vector<MlirType> containedTypes;
     for (const c10::TypePtr &type :
          torchType->cast<c10::TupleType>()->containedTypes()) {
-      containedTypes.push_back(mapFromTorchType(loc, type));
+      containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
     }
     return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
                                       containedTypes.data());
   }
   case TypeKind::ListType: {
-    return torchMlirTorchListTypeGet(mapFromTorchType(
+    return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
         loc, torchType->cast<c10::ListType>()->getElementType()));
   }
   case TypeKind::DictType: {
     auto dictType = torchType->cast<c10::DictType>();
     return torchMlirTorchDictTypeGet(
-        mapFromTorchType(loc, dictType->getKeyType()),
-        mapFromTorchType(loc, dictType->getValueType()));
+        getMlirTypeFromTorchType(loc, dictType->getKeyType()),
+        getMlirTypeFromTorchType(loc, dictType->getValueType()));
   }
   case TypeKind::NoneType: {
     return torchMlirTorchNoneTypeGet(context);
@@ -223,29 +217,12 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
   }
 }
 
-MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
-  if (!tensor.defined()) {
-    // Undefined tensors are equivalent to None.
-    // This may need to be re-evaluated at some point.
-    return torchMlirTorchNoneTypeGet(context);
-  }
-
-  MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
-  // TODO: Decide when it is necessary to take strides into account. Right now,
-  // just erase them and let the compiler decide.
-
-  auto sizes = tensor.sizes();
-  return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
-                                             sizes.data(), elementType);
-}
-
 MlirType
 torch_mlir::getFunctionTypeFromSchema(MlirContext context,
                                       const c10::FunctionSchema &schema) {
   MlirLocation loc = mlirLocationUnknownGet(context);
-  TypeMapper typeMapper(context);
   auto mapType = [&](const c10::TypePtr &torchType) {
-    MlirType type = typeMapper.mapFromTorchType(loc, torchType);
+    MlirType type = getMlirTypeFromTorchType(loc, torchType);
     if (mlirTypeIsNull(type)) {
       std::stringstream msg;
       msg << "unsupported type in function schema: '"
@@ -267,8 +244,6 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
 
 MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
                                                           MlirLocation loc) {
-  MlirContext context = mlirLocationGetContext(loc);
-  TypeMapper typeMapper(context);
   using at::ScalarType;
 
   auto throwUnsupportedTensorError = [&]() {
@@ -292,8 +267,8 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   // quantized types it might differ (e.g. QInt8 becomes Char). Caller code is
   // responsible for materializing the proper op that incorporates the
   // quantization scheme to create a tensor of e.g. `!torch.qint8` element type.
-  MlirType elementType = typeMapper.mapFromTorchScalarType(
-      c10::toUnderlying(tensor.scalar_type()));
+  MlirType elementType = getMlirTypeForTorchScalarType(
+      loc, c10::toUnderlying(tensor.scalar_type()));
   std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
   MlirType shapedType = mlirRankedTensorTypeGetChecked(
       loc, shape.size(), shape.data(), elementType, {nullptr});
@@ -378,10 +353,9 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
 std::vector<MlirType>
 torch_mlir::getMlirTypesFromValues(MlirLocation loc,
                                    c10::ArrayRef<torch::jit::Value *> values) {
-  TypeMapper typeMapper(mlirLocationGetContext(loc));
   std::vector<MlirType> ret;
   for (auto value : values) {
-    MlirType t = typeMapper.mapFromTorchType(loc, value->type());
+    MlirType t = getMlirTypeFromTorchType(loc, value->type());
     if (mlirTypeIsNull(t))
       throw mlir_diagnostic_emitted("unsupported type");
     ret.push_back(t);
@@ -20,44 +20,22 @@
 
 namespace torch_mlir {
 
-/// Maps various runtime types to MlirType.
-class TypeMapper {
-public:
-  TypeMapper(MlirContext context) : context(context) {}
+/// Gets a corresponding MlirType for the Torch ScalarType.
+/// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
+/// type universe from the Python-based types modeled with `c10::Type`.
+/// Compared to the Python types, which just have `int` and `float` and
+/// `bool` numeric types, ScalarType has detailed bit-width and precision
+/// considerations (which matter a lot for tensors, but don't really matter
+/// for Python code).
+///
+/// Returns a null type on failure and emits a diagnostic.
+MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
+                                       c10::ScalarType scalarType);
 
-  /// Gets a corresponding MlirType for the Torch ScalarType.
-  /// `c10::`ScalarType` is used to represent tensor dtypes, and is a different
-  /// type universe from the Python-based types modeled with `c10::Type`.
-  /// Compared to the Python types, which just have `int` and `float` and
-  /// `bool` numeric types, ScalarType has detailed bit-width and precision
-  /// considerations (which matter a lot for tensors, but don't really matter
-  /// for Python code).
-  ///
-  /// Throws std::invalid_argument on failure.
-  MlirType mapFromTorchScalarType(c10::ScalarType scalarType);
-
-  /// Gets a corresponding MlirType for the forward component of a tensor.
-  /// Throws std::invalid_argument on failure.
-  MlirType forwardTensorToType(at::Tensor tensor);
-
-  /// Gets a corresponding MlirType for the Torch ScalarType.
-  /// Torch ScalarType is used to represent the possible element types of Torch
-  /// tensors, which is different from the set of types used to represent
-  /// Python numeric scalar values (which are always either f64 or !torch.int).
-  ///
-  /// Returns a null type on failure and emits a diagnostic.
-  MlirType mapFromTorchScalarType(MlirLocation loc, c10::ScalarType scalarType);
-
-  /// Maps a torch type to a corresponding MlirType. Returns a null type
-  /// on failure and emits a diagnostic.
-  MlirType mapFromTorchType(MlirLocation loc, const c10::TypePtr &torchType);
-
-private:
-  /// Maps from a scalar type and returns null if no match (no other error
-  /// reporting).
-  MlirType rawMapFromTorchScalarType(c10::ScalarType scalarType);
-  MlirContext context;
-};
+/// Maps a torch type to a corresponding MlirType. Returns a null type
+/// on failure and emits a diagnostic.
+MlirType getMlirTypeFromTorchType(MlirLocation loc,
+                                  const c10::TypePtr &torchType);
 
 /// Creates a FunctionType suitable for expressing the signature of `schema`.
 ///
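One behavioral note that falls out of the header diff above: the removed context-only overload TypeMapper::mapFromTorchScalarType threw std::invalid_argument, whereas both remaining free functions emit a diagnostic and return a null type, so callers branch on mlirTypeIsNull rather than catching an exception. A hedged sketch of that calling pattern, modeled on getMlirTypesFromValues in this diff; the function name below and the use of std::runtime_error in place of the repository's mlir_diagnostic_emitted are illustrative assumptions.

#include "torch_to_mlir_utils.h"

#include <stdexcept>
#include <vector>

// Illustrative only; mirrors the null-check convention used in this commit.
static std::vector<MlirType>
convertValueTypesOrThrow(MlirLocation loc,
                         c10::ArrayRef<torch::jit::Value *> values) {
  std::vector<MlirType> types;
  for (torch::jit::Value *value : values) {
    MlirType type = torch_mlir::getMlirTypeFromTorchType(loc, value->type());
    if (mlirTypeIsNull(type)) {
      // The callee has already emitted a diagnostic at `loc`; the null type
      // is the only failure signal, so surface it however the caller prefers.
      throw std::runtime_error("unsupported Torch type");
    }
    types.push_back(type);
  }
  return types;
}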