Import Machine Translation model to MLIR.

This includes the following changes to import MT model into MLIR. There
are still a lot of work to for actual compilation.
- Add `torch.dict<>`, `torch.any`, `torch.number` types
- Add `torch.prim.DictConstruct` op
- Fix `torch.prim.TupleConstruct` op assembly format to include resulting types
pull/276/head
Yi Zhang 2021-08-07 22:33:39 -04:00
parent a3bfd115ee
commit bfc3ee35c6
13 changed files with 444 additions and 49 deletions

View File

@ -6,8 +6,7 @@ src_dir="$(realpath $(dirname $0)/..)"
build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")"
torch_ir_dir="${src_dir}/include/npcomp/Dialect/Torch/IR"
export PYTHONPATH="${build_dir}/python"
source $src_dir/.env
#ninja -C "${build_dir}"
python -m torch_mlir_utils.codegen.torch_ods_gen \
--torch_ir_dir="${torch_ir_dir}" \

View File

@ -48,7 +48,7 @@ using namespace torch_mlir;
namespace {
struct IValueHasher {
size_t operator()(const c10::IValue &ivalue) const {
if (ivalue.isObject() || ivalue.isList()) {
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
return std::hash<const void *>()(
static_cast<const void *>(ivalue.internalToPointer()));
}
@ -278,6 +278,22 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isGenericDict()) {
c10::Dict<c10::IValue, c10::IValue> dict = ivalue.toGenericDict();
std::vector<MlirValue> keys;
std::vector<MlirValue> values;
for (auto it = dict.begin(); it != dict.end(); it++) {
keys.push_back(importIValue(it->key()));
values.push_back(importIValue(it->value()));
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc,
npcompTorchDictTypeGet(
typeMapper.mapFromTorchType(loc, dict.keyType()),
typeMapper.mapFromTorchType(loc, dict.valueType())),
keys, values);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTuple()) {
auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> operands;

View File

@ -44,18 +44,45 @@ private:
};
} // namespace
using InputsTransformFn =
std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>;
// The inputs of `DictConstruct` in TorchScript IR are in the order
// like k0, v0, k1, v1. Rearrange them to put the key operands together and
// then the value operands like k0, k1,v0, v1. This is the expected format by
// the corresponding MLIR op.
static std::vector<MlirValue>
rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
if (inputs.empty())
return inputs;
assert(inputs.size() % 2 == 0 &&
"DictConstruct must have even number of operands");
std::vector<MlirValue> rearranged;
std::vector<MlirValue> values;
for (auto it = inputs.begin(); it != inputs.end(); it++) {
rearranged.push_back(*it);
values.push_back(*++it);
}
rearranged.insert(rearranged.end(), values.begin(), values.end());
return rearranged;
}
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
TypeMapper typeMapper(context);
MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind();
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName) {
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
t ? t(mappedInputs) : mappedInputs);
mapResults(node, operation);
};
auto createAndMapNodeWithAttribute = [&](Node *node,
const std::string &opName,
const std::string &attrName,
@ -80,12 +107,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
// Builtin interpreter ops with no operator/schema.
InputsTransformFn transformer =
kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
switch (kind) {
case c10::prim::ListUnpack:
case c10::prim::ListConstruct:
case c10::prim::TupleConstruct: {
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()));
case c10::prim::TupleConstruct:
case c10::prim::DictConstruct: {
createAndMapTrivialNode(
node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
return;
}
case c10::prim::GetAttr:

View File

@ -155,35 +155,24 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
/*optionalDtype=*/
elementType);
}
case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
MlirType customClassType = mapCustomClassType(context, loc, classType);
if (!mlirTypeIsNull(customClassType)) {
return customClassType;
}
auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
case TypeKind::IntType: {
return npcompTorchIntTypeGet(context);
}
case TypeKind::FloatType: {
return npcompTorchFloatTypeGet(context);
}
case TypeKind::OptionalType: {
return npcompTorchOptionalTypeGet(mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
}
case TypeKind::IntType: {
return npcompTorchIntTypeGet(context);
}
case TypeKind::NoneType: {
return npcompTorchNoneTypeGet(context);
}
case TypeKind::BoolType: {
return npcompTorchBoolTypeGet(context);
}
case TypeKind::ListType: {
return npcompTorchListTypeGet(mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
case TypeKind::NumberType: {
return npcompTorchNumberTypeGet(context);
}
case TypeKind::StringType: {
return npcompTorchStringTypeGet(context);
}
case TypeKind::OptionalType: {
return npcompTorchOptionalTypeGet(mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
}
case TypeKind::TupleType: {
std::vector<MlirType> containedTypes;
@ -194,8 +183,32 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return npcompTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::StringType: {
return npcompTorchStringTypeGet(context);
case TypeKind::ListType: {
return npcompTorchListTypeGet(mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::DictType: {
auto dictType = torchType->cast<c10::DictType>();
return npcompTorchDictTypeGet(
mapFromTorchType(loc, dictType->getKeyType()),
mapFromTorchType(loc, dictType->getValueType()));
}
case TypeKind::NoneType: {
return npcompTorchNoneTypeGet(context);
}
case TypeKind::AnyType: {
auto anyType = torchType->cast<c10::AnyType>();
return npcompTorchAnyTypeGet(context);
}
case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
MlirType customClassType = mapCustomClassType(context, loc, classType);
if (!mlirTypeIsNull(customClassType)) {
return customClassType;
}
auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
}
case TypeKind::DeviceObjType: {
return npcompTorchDeviceTypeGet(context);
@ -235,7 +248,7 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
if (mlirTypeIsNull(type)) {
std::stringstream msg;
msg << "unsupported type in function schema: '"
<< c10::toString(torchType) << "'";
<< c10::toString(torchType) << "'";
throw std::invalid_argument(msg.str());
}
return type;

View File

@ -0,0 +1,39 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
from typing import Dict, Optional
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.d = {"key1": torch.tensor(1)}
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# CHECK: torch.attr "training" : !torch.bool
# CHECK: torch.attr "_is_full_backward_hook" : !torch.optional<!torch.bool>
# CHECK: torch.attr "d" : !torch.dict<!torch.str, !torch.tensor>
# CHECK: }
# CHECK: %[[K:.*]] = torch.constant.str "key1"
# CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1> : tensor<si64>) : !torch.tensor<[],si64>
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
# CHECK-SAME keys(%[[K]] : !torch.str) values(%[[TENSOR]] : !torch.tensor<[],si64>)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
# CHECK: torch.nn_module {
# CHECK: torch.slot "d", %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()

View File

@ -0,0 +1,43 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import collections
from typing import Tuple, Optional, List, NamedTuple, Dict
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: builtin.func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<!torch.str, !torch.tensor>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
@mb.import_function
@torch.jit.script
def dict_literal_empty() -> Dict[str, torch.Tensor]:
return {}
# CHECK-LABEL: builtin.func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.optional<!torch.tensor>> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
# CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
# CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor) ->
# CHECK-SAME: !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
# CHECK: }
@mb.import_function
@torch.jit.script
def dict_literal(k0: str, v0, k1: str,
v1) -> Dict[str, Optional[torch.Tensor]]:
my_dict: Dict[str, Optional[torch.Tensor]] = {k0: v0, k1: v1}
return my_dict
mb.module.operation.print()
print()

View File

@ -2,23 +2,31 @@
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import enum
import torch
import torch_mlir
class Color(enum.Enum):
RED = 1
GREEN = 2
# RUN: %PYTHON %s
mb = torch_mlir.ModuleBuilder()
# To test errors, use a type that we don't support yet.
try:
@mb.import_function
@torch.jit.script
def import_class(x: typing.Any):
def import_class(x: Color):
return x
except Exception as e:
# TODO: Once diagnostics are enabled, verify the actual error emitted.
assert str(e) == "unsupported type in function schema: 'Any'"
assert str(
e) == "unsupported type in function schema: 'Enum<__torch__.Color>'"
else:
assert False, "Expected exception"

View File

@ -4,22 +4,65 @@
import torch
import torch_mlir
import collections
from typing import Tuple, Optional, NamedTuple
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
('f2', Optional[torch.Tensor])])
# CHECK-LABEL: func @__torch__.f(
# CHECK-LABEL: builtin.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>
@mb.import_function
@torch.jit.script
def f(t0, t1):
def tuple(t0, t1):
return t0, t1
assert isinstance(f, torch.jit.ScriptFunction)
# CHECK-LABEL: builtin.func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
# CHECK: %[[TNEW:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
# CHECK: %[[RET:.*]] = torch.derefine %[[TNEW]] :
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> to
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
@mb.import_function
@torch.jit.script
def tuple_optional(
t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
return t0, t1
# CHECK-LABEL: builtin.func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK: }
#
@mb.import_function
@torch.jit.script
def namedtuple_optional(
t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
return NT(t0, t1)
mb.module.operation.print()
print()

View File

@ -190,6 +190,37 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.any type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.any type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.number type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.number type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t);
/// Gets the !torch.number type.
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.dict type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.dict type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t);
/// Gets the !torch.dict type.
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType,
MlirType valueType);
#ifdef __cplusplus
}
#endif

View File

@ -326,7 +326,8 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
NoSideEffect,
TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))">
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
"isValidSubtype">
]> {
let summary = "TorchScript prim::TupleConstruct op";
let description = [{
@ -342,7 +343,7 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
);
let assemblyFormat = [{
$elements attr-dict `:` type($elements)
$elements attr-dict `:` type($elements) `->` type($result)
}];
}
@ -366,6 +367,33 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
}];
}
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
AllowsTypeRefinement,
SameVariadicOperandSize,
]> {
let summary = "TorchScript prim::DictConstruct op";
let arguments = (ins
Variadic<AnyTorchDictKeyType>:$keys,
Variadic<AnyTorchType>:$values
);
let results = (outs
Torch_DictType:$result
);
let verifier = "return ::verify(*this);";
let assemblyFormat = [{
`keys` `(` ($keys^ `:` type($keys))? `)` `values` `(` ($values^ `:` type($values))? `)` attr-dict `->` type($result)
}];
let extraClassDeclaration = [{
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
Type getValueType() { return getType().cast<DictType>().getValueType(); }
}];
}
def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
let summary = "TorchScript prim::GetAttr op";

View File

@ -321,6 +321,56 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
}];
}
def Torch_AnyType : Torch_Type<"Any", "any"> {
let summary = "Torch any type";
let description = [{
Represent any torch type. All the other types are sub types of Any type.
}];
}
def Torch_NumberType : Torch_Type<"Number", "number"> {
let summary = "Torch number type";
let description = [{
The Int, Float and Complex type are sub types of Number type.
}];
}
def Torch_DictType : Torch_Type<"Dict", "dict"> {
let summary = "!torch.dict[KT, VT] ";
let parameters = (ins "::mlir::Type":$keyType, "::mlir::Type":$valueType);
let description = [{
Torch Dict type with key and value type.
}];
let printer = [{
$_printer << getMnemonic() << "<" << getImpl()->keyType << ", " << getImpl()->valueType << ">";
}];
let parser = [{
if (parser.parseLess())
return Type();
Type keyType;
if ($_parser.parseType(keyType))
return Type();
if ($_parser.parseComma())
return Type();
Type valueType;
if ($_parser.parseType(valueType))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, keyType, valueType);
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
"::mlir::Type":$valueType), [{
return Base::get(keyType.getContext(), keyType, valueType);
}]>
];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
@ -348,17 +398,33 @@ def AnyTorchScalarType : AnyTypeOf<[
Torch_BoolType,
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
// See function `DictTypePtr create(TypePtr key, TypePtr value)`
// in aten/src/ATen/core/jit_type.h.
def AnyTorchDictKeyType : AnyTypeOf<[
Torch_AnyType,
Torch_IntType,
Torch_BoolType,
Torch_FloatType,
Torch_StringType,
Torch_FloatType,
AnyTorchTensorType,
], "Allowed dict key types">;
// In alphabetic order.
def AnyTorchType : AnyTypeOf<[
AnyTorchScalarType,
AnyTorchTensorType,
Torch_TupleType,
Torch_StringType,
Torch_AnyType,
Torch_DictType,
Torch_DeviceType,
Torch_ListType,
Torch_LinearParamsType,
Torch_NumberType,
Torch_NnModuleType,
Torch_NoneType,
Torch_OptionalType,
Torch_ListType,
Torch_DeviceType,
Torch_LinearParamsType,
Torch_StringType,
Torch_TupleType,
], "Any type that is legal to pass to a Torch kernel">;
def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">;

View File

@ -225,3 +225,39 @@ bool npcompTypeIsATorchString(MlirType t) {
MlirType npcompTorchStringTypeGet(MlirContext context) {
return wrap(Torch::StringType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.any type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchAny(MlirType t) {
return unwrap(t).isa<Torch::AnyType>();
}
MlirType npcompTorchAnyTypeGet(MlirContext context) {
return wrap(Torch::AnyType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.number type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNumber(MlirType t) {
return unwrap(t).isa<Torch::NumberType>();
}
MlirType npcompTorchNumberTypeGet(MlirContext context) {
return wrap(Torch::NumberType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.Dict type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDict(MlirType t) {
return unwrap(t).isa<Torch::DictType>();
}
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) {
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
}

View File

@ -93,9 +93,31 @@ static LogicalResult verify(NnModuleOp op) {
bool isValidSubtype(Type subtype, Type type) {
if (subtype == type)
return true;
if (auto any = type.dyn_cast<AnyType>())
return true;
if (auto number = type.dyn_cast<NumberType>())
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
if (auto optional = type.dyn_cast<OptionalType>())
return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>();
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (!subtype.isa<Torch::TupleType>())
return false;
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
auto types = tuple.getContainedTypes();
if (subtypes.size() != types.size())
return false;
for (auto t : llvm::zip(subtypes, types)) {
if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
return false;
}
return true;
}
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
@ -142,7 +164,7 @@ static LogicalResult verify(PrimListConstructOp op) {
auto resultType = op.getResult().getType();
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
auto matchResultElementType = [&](Type type) {
return type.getTypeID() == resultElementType.getTypeID();
return isValidSubtype(type, resultElementType);
};
if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) {
return op.emitError() << "operand types should have the same type as the "
@ -152,6 +174,27 @@ static LogicalResult verify(PrimListConstructOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// PrimDictConstructOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(PrimDictConstructOp op) {
auto isValidSubTypeOf = [](Type expectedType) {
return [=](Type type) { return isValidSubtype(type, expectedType); };
};
Type keyType = op.getKeyType();
if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType))) {
return op.emitError() << "keys should be of Dict key type";
}
Type valueType = op.getValueType();
if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType))) {
return op.emitError() << "values should be of Dict value type";
}
return success();
}
//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//