From bfc3ee35c67f56ac7473c52eeee5eb72cde690b7 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 7 Aug 2021 22:33:39 -0400 Subject: [PATCH] Import Machine Translation model to MLIR. This includes the following changes to import MT model into MLIR. There are still a lot of work to for actual compilation. - Add `torch.dict<>`, `torch.any`, `torch.number` types - Add `torch.prim.DictConstruct` op - Fix `torch.prim.TupleConstruct` op assembly format to include resulting types --- build_tools/update_torch_ods.sh | 3 +- .../pytorch/csrc/builder/ivalue_importer.cpp | 18 ++++- .../pytorch/csrc/builder/node_importer.cpp | 40 ++++++++-- .../csrc/builder/torch_to_mlir_utils.cpp | 63 +++++++++------ frontends/pytorch/test/ivalue_import/dict.py | 39 ++++++++++ frontends/pytorch/test/node_import/dict.py | 43 +++++++++++ frontends/pytorch/test/node_import/errors.py | 14 +++- frontends/pytorch/test/node_import/tuple.py | 53 +++++++++++-- include/npcomp-c/TorchTypes.h | 31 ++++++++ include/npcomp/Dialect/Torch/IR/TorchOps.td | 32 +++++++- include/npcomp/Dialect/Torch/IR/TorchTypes.td | 76 +++++++++++++++++-- lib/CAPI/TorchTypes.cpp | 36 +++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 45 ++++++++++- 13 files changed, 444 insertions(+), 49 deletions(-) create mode 100644 frontends/pytorch/test/ivalue_import/dict.py create mode 100644 frontends/pytorch/test/node_import/dict.py diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh index 9e7c865eb..e0c1dfcf4 100755 --- a/build_tools/update_torch_ods.sh +++ b/build_tools/update_torch_ods.sh @@ -6,8 +6,7 @@ src_dir="$(realpath $(dirname $0)/..)" build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")" torch_ir_dir="${src_dir}/include/npcomp/Dialect/Torch/IR" -export PYTHONPATH="${build_dir}/python" - +source $src_dir/.env #ninja -C "${build_dir}" python -m torch_mlir_utils.codegen.torch_ods_gen \ --torch_ir_dir="${torch_ir_dir}" \ diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index 8b87ed873..dcf8c7f11 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -48,7 +48,7 @@ using namespace torch_mlir; namespace { struct IValueHasher { size_t operator()(const c10::IValue &ivalue) const { - if (ivalue.isObject() || ivalue.isList()) { + if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { return std::hash()( static_cast(ivalue.internalToPointer())); } @@ -278,6 +278,22 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { elems); return mlirOperationGetResult(operation, 0); } + if (ivalue.isGenericDict()) { + c10::Dict dict = ivalue.toGenericDict(); + std::vector keys; + std::vector values; + for (auto it = dict.begin(); it != dict.end(); it++) { + keys.push_back(importIValue(it->key())); + values.push_back(importIValue(it->value())); + } + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "torch.prim.DictConstruct", loc, + npcompTorchDictTypeGet( + typeMapper.mapFromTorchType(loc, dict.keyType()), + typeMapper.mapFromTorchType(loc, dict.valueType())), + keys, values); + return mlirOperationGetResult(operation, 0); + } if (ivalue.isTuple()) { auto list = ivalue.toTuple()->elements(); std::vector operands; diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp index b8b85950f..ac3464bf7 100644 --- a/frontends/pytorch/csrc/builder/node_importer.cpp +++ b/frontends/pytorch/csrc/builder/node_importer.cpp @@ -44,18 +44,45 @@ private: }; } // namespace +using InputsTransformFn = + std::function(std::vector &)>; + +// The inputs of `DictConstruct` in TorchScript IR are in the order +// like k0, v0, k1, v1. Rearrange them to put the key operands together and +// then the value operands like k0, k1,v0, v1. This is the expected format by +// the corresponding MLIR op. +static std::vector +rearrangeDictConstructInputs(std::vector &inputs) { + if (inputs.empty()) + return inputs; + assert(inputs.size() % 2 == 0 && + "DictConstruct must have even number of operands"); + + std::vector rearranged; + std::vector values; + for (auto it = inputs.begin(); it != inputs.end(); it++) { + rearranged.push_back(*it); + values.push_back(*++it); + } + rearranged.insert(rearranged.end(), values.begin(), values.end()); + return rearranged; +} + void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { TypeMapper typeMapper(context); MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); - auto createAndMapTrivialNode = [&](Node *node, const std::string &opName) { + auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, + InputsTransformFn t) { + std::vector mappedInputs = lookupMappedValues(node->inputs()); MlirOperation operation = createMlirOperationAtEnd(appendToBlock, opName, loc, getMlirTypesFromValues(loc, node->outputs()), - lookupMappedValues(node->inputs())); + t ? t(mappedInputs) : mappedInputs); mapResults(node, operation); }; + auto createAndMapNodeWithAttribute = [&](Node *node, const std::string &opName, const std::string &attrName, @@ -80,12 +107,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { } // Builtin interpreter ops with no operator/schema. + InputsTransformFn transformer = + kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs; switch (kind) { case c10::prim::ListUnpack: case c10::prim::ListConstruct: - case c10::prim::TupleConstruct: { - createAndMapTrivialNode(node, - "torch.prim." + std::string(kind.toUnqualString())); + case c10::prim::TupleConstruct: + case c10::prim::DictConstruct: { + createAndMapTrivialNode( + node, "torch.prim." + std::string(kind.toUnqualString()), transformer); return; } case c10::prim::GetAttr: diff --git a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp index ec7b733d3..e127c48e0 100644 --- a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp +++ b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp @@ -155,35 +155,24 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, /*optionalDtype=*/ elementType); } - case TypeKind::ClassType: { - const c10::ClassTypePtr &classType = torchType->cast(); - MlirType customClassType = mapCustomClassType(context, loc, classType); - if (!mlirTypeIsNull(customClassType)) { - return customClassType; - } - auto maybeName = classType->name(); - std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class"; - return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name)); + case TypeKind::IntType: { + return npcompTorchIntTypeGet(context); } case TypeKind::FloatType: { return npcompTorchFloatTypeGet(context); } - case TypeKind::OptionalType: { - return npcompTorchOptionalTypeGet(mapFromTorchType( - loc, torchType->cast()->getElementType())); - } - case TypeKind::IntType: { - return npcompTorchIntTypeGet(context); - } - case TypeKind::NoneType: { - return npcompTorchNoneTypeGet(context); - } case TypeKind::BoolType: { return npcompTorchBoolTypeGet(context); } - case TypeKind::ListType: { - return npcompTorchListTypeGet(mapFromTorchType( - loc, torchType->cast()->getElementType())); + case TypeKind::NumberType: { + return npcompTorchNumberTypeGet(context); + } + case TypeKind::StringType: { + return npcompTorchStringTypeGet(context); + } + case TypeKind::OptionalType: { + return npcompTorchOptionalTypeGet(mapFromTorchType( + loc, torchType->cast()->getElementType())); } case TypeKind::TupleType: { std::vector containedTypes; @@ -194,8 +183,32 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, return npcompTorchTupleTypeGet(context, containedTypes.size(), containedTypes.data()); } - case TypeKind::StringType: { - return npcompTorchStringTypeGet(context); + case TypeKind::ListType: { + return npcompTorchListTypeGet(mapFromTorchType( + loc, torchType->cast()->getElementType())); + } + case TypeKind::DictType: { + auto dictType = torchType->cast(); + return npcompTorchDictTypeGet( + mapFromTorchType(loc, dictType->getKeyType()), + mapFromTorchType(loc, dictType->getValueType())); + } + case TypeKind::NoneType: { + return npcompTorchNoneTypeGet(context); + } + case TypeKind::AnyType: { + auto anyType = torchType->cast(); + return npcompTorchAnyTypeGet(context); + } + case TypeKind::ClassType: { + const c10::ClassTypePtr &classType = torchType->cast(); + MlirType customClassType = mapCustomClassType(context, loc, classType); + if (!mlirTypeIsNull(customClassType)) { + return customClassType; + } + auto maybeName = classType->name(); + std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class"; + return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name)); } case TypeKind::DeviceObjType: { return npcompTorchDeviceTypeGet(context); @@ -235,7 +248,7 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context, if (mlirTypeIsNull(type)) { std::stringstream msg; msg << "unsupported type in function schema: '" - << c10::toString(torchType) << "'"; + << c10::toString(torchType) << "'"; throw std::invalid_argument(msg.str()); } return type; diff --git a/frontends/pytorch/test/ivalue_import/dict.py b/frontends/pytorch/test/ivalue_import/dict.py new file mode 100644 index 000000000..b90988d59 --- /dev/null +++ b/frontends/pytorch/test/ivalue_import/dict.py @@ -0,0 +1,39 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +from typing import Dict, Optional + +import torch +import torch_mlir + +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +mb = torch_mlir.ModuleBuilder() + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = {"key1": torch.tensor(1)} + + +# CHECK: torch.class_type @[[CLASSTYPE:.*]] { +# CHECK: torch.attr "training" : !torch.bool +# CHECK: torch.attr "_is_full_backward_hook" : !torch.optional +# CHECK: torch.attr "d" : !torch.dict +# CHECK: } +# CHECK: %[[K:.*]] = torch.constant.str "key1" +# CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1> : tensor) : !torch.tensor<[],si64> +# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct +# CHECK-SAME keys(%[[K]] : !torch.str) values(%[[TENSOR]] : !torch.tensor<[],si64>) +# CHECK-SAME: -> !torch.dict +# CHECK: torch.nn_module { +# CHECK: torch.slot "d", %[[DICT]] : !torch.dict +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> + +test_module = TestModule() +recursivescriptmodule = torch.jit.script(test_module) +# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. +mb.import_module(recursivescriptmodule._c) +mb.module.operation.print() diff --git a/frontends/pytorch/test/node_import/dict.py b/frontends/pytorch/test/node_import/dict.py new file mode 100644 index 000000000..4e83197a3 --- /dev/null +++ b/frontends/pytorch/test/node_import/dict.py @@ -0,0 +1,43 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch_mlir +import collections +from typing import Tuple, Optional, List, NamedTuple, Dict + +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +mb = torch_mlir.ModuleBuilder() + + +# CHECK-LABEL: builtin.func @__torch__.dict_literal_empty() -> !torch.dict { +# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict +# CHECK: return %[[DICT]] : !torch.dict +@mb.import_function +@torch.jit.script +def dict_literal_empty() -> Dict[str, torch.Tensor]: + return {} + + +# CHECK-LABEL: builtin.func @__torch__.dict_literal( +# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor, +# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor) +# CHECK-SAME: -> !torch.dict> { +# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct +# CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str) +# CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor) -> +# CHECK-SAME: !torch.dict> +# CHECK: return %[[DICT]] : !torch.dict> +# CHECK: } +@mb.import_function +@torch.jit.script +def dict_literal(k0: str, v0, k1: str, + v1) -> Dict[str, Optional[torch.Tensor]]: + my_dict: Dict[str, Optional[torch.Tensor]] = {k0: v0, k1: v1} + return my_dict + + +mb.module.operation.print() +print() diff --git a/frontends/pytorch/test/node_import/errors.py b/frontends/pytorch/test/node_import/errors.py index 43c7445c6..8de43be44 100644 --- a/frontends/pytorch/test/node_import/errors.py +++ b/frontends/pytorch/test/node_import/errors.py @@ -2,23 +2,31 @@ # This file is licensed under a pytorch-style license # See frontends/pytorch/LICENSE for license information. -import typing +import enum import torch import torch_mlir + +class Color(enum.Enum): + RED = 1 + GREEN = 2 + + # RUN: %PYTHON %s mb = torch_mlir.ModuleBuilder() # To test errors, use a type that we don't support yet. try: + @mb.import_function @torch.jit.script - def import_class(x: typing.Any): + def import_class(x: Color): return x except Exception as e: # TODO: Once diagnostics are enabled, verify the actual error emitted. - assert str(e) == "unsupported type in function schema: 'Any'" + assert str( + e) == "unsupported type in function schema: 'Enum<__torch__.Color>'" else: assert False, "Expected exception" diff --git a/frontends/pytorch/test/node_import/tuple.py b/frontends/pytorch/test/node_import/tuple.py index 1b7b07798..f9539cf5c 100644 --- a/frontends/pytorch/test/node_import/tuple.py +++ b/frontends/pytorch/test/node_import/tuple.py @@ -4,22 +4,65 @@ import torch import torch_mlir +import collections +from typing import Tuple, Optional, NamedTuple # RUN: %PYTHON %s | npcomp-opt | FileCheck %s mb = torch_mlir.ModuleBuilder() +NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]), + ('f2', Optional[torch.Tensor])]) -# CHECK-LABEL: func @__torch__.f( +# CHECK-LABEL: builtin.func @__torch__.tuple( # CHECK-SAME: %[[T0:.*]]: !torch.tensor, -# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple { -# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor +# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> +# CHECK-SAME: !torch.tuple { +# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : +# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple # CHECK: return %[[RET]] : !torch.tuple + @mb.import_function @torch.jit.script -def f(t0, t1): +def tuple(t0, t1): return t0, t1 -assert isinstance(f, torch.jit.ScriptFunction) + +# CHECK-LABEL: builtin.func @__torch__.tuple_optional( +# CHECK-SAME: %[[T0:.*]]: !torch.tensor, +# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> +# CHECK-SAME: !torch.tuple, !torch.optional> { +# CHECK: %[[TNEW:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : +# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple +# CHECK: %[[RET:.*]] = torch.derefine %[[TNEW]] : +# CHECK-SAME: !torch.tuple to +# CHECK-SAME: !torch.tuple, !torch.optional> +# CHECK: return %[[RET]] : !torch.tuple, !torch.optional> + + +@mb.import_function +@torch.jit.script +def tuple_optional( + t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + return t0, t1 + + +# CHECK-LABEL: builtin.func @__torch__.namedtuple_optional( +# CHECK-SAME: %[[T0:.*]]: !torch.tensor, +# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> +# CHECK-SAME: !torch.tuple, !torch.optional> { +# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : +# CHECK-SAME: !torch.tensor, !torch.tensor -> +# CHECK-SAME: !torch.tuple, !torch.optional> +# CHECK: return %[[RET]] : !torch.tuple, !torch.optional> +# CHECK: } +# +@mb.import_function +@torch.jit.script +def namedtuple_optional( + t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + return NT(t0, t1) + + mb.module.operation.print() print() diff --git a/include/npcomp-c/TorchTypes.h b/include/npcomp-c/TorchTypes.h index 748c2d324..3f832b664 100644 --- a/include/npcomp-c/TorchTypes.h +++ b/include/npcomp-c/TorchTypes.h @@ -190,6 +190,37 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t); /// Gets the !torch.str type. MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context); +//===----------------------------------------------------------------------===// +// !torch.any type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.any type. +MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t); + +/// Gets the !torch.str type. +MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.number type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.number type. +MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t); + +/// Gets the !torch.number type. +MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.dict type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.dict type. +MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t); + +/// Gets the !torch.dict type. +MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType, + MlirType valueType); + #ifdef __cplusplus } #endif diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index a75474116..9273c0294 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -326,7 +326,8 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ NoSideEffect, TypesMatchWith<"contained types correspond to operand types", - "elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))"> + "elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))", + "isValidSubtype"> ]> { let summary = "TorchScript prim::TupleConstruct op"; let description = [{ @@ -342,7 +343,7 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ ); let assemblyFormat = [{ - $elements attr-dict `:` type($elements) + $elements attr-dict `:` type($elements) `->` type($result) }]; } @@ -366,6 +367,33 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ }]; } +def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ + AllowsTypeRefinement, + SameVariadicOperandSize, + ]> { + let summary = "TorchScript prim::DictConstruct op"; + + let arguments = (ins + Variadic:$keys, + Variadic:$values + ); + + let results = (outs + Torch_DictType:$result + ); + + let verifier = "return ::verify(*this);"; + + let assemblyFormat = [{ + `keys` `(` ($keys^ `:` type($keys))? `)` `values` `(` ($values^ `:` type($values))? `)` attr-dict `->` type($result) + }]; + + let extraClassDeclaration = [{ + Type getKeyType() { return getType().cast().getKeyType(); } + Type getValueType() { return getType().cast().getValueType(); } + }]; +} + def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> { let summary = "TorchScript prim::GetAttr op"; diff --git a/include/npcomp/Dialect/Torch/IR/TorchTypes.td b/include/npcomp/Dialect/Torch/IR/TorchTypes.td index 1d038fb03..9d8cfc003 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchTypes.td +++ b/include/npcomp/Dialect/Torch/IR/TorchTypes.td @@ -321,6 +321,56 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> { }]; } +def Torch_AnyType : Torch_Type<"Any", "any"> { + let summary = "Torch any type"; + let description = [{ + Represent any torch type. All the other types are sub types of Any type. + }]; +} + +def Torch_NumberType : Torch_Type<"Number", "number"> { + let summary = "Torch number type"; + let description = [{ + The Int, Float and Complex type are sub types of Number type. + }]; +} + +def Torch_DictType : Torch_Type<"Dict", "dict"> { + + let summary = "!torch.dict[KT, VT] "; + let parameters = (ins "::mlir::Type":$keyType, "::mlir::Type":$valueType); + let description = [{ + Torch Dict type with key and value type. + }]; + + let printer = [{ + $_printer << getMnemonic() << "<" << getImpl()->keyType << ", " << getImpl()->valueType << ">"; + }]; + + let parser = [{ + if (parser.parseLess()) + return Type(); + Type keyType; + if ($_parser.parseType(keyType)) + return Type(); + if ($_parser.parseComma()) + return Type(); + Type valueType; + if ($_parser.parseType(valueType)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, keyType, valueType); + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType, + "::mlir::Type":$valueType), [{ + return Base::get(keyType.getContext(), keyType, valueType); + }]> + ]; +} + //===----------------------------------------------------------------------===// // Type predicates //===----------------------------------------------------------------------===// @@ -348,17 +398,33 @@ def AnyTorchScalarType : AnyTypeOf<[ Torch_BoolType, ], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">; +// See function `DictTypePtr create(TypePtr key, TypePtr value)` +// in aten/src/ATen/core/jit_type.h. +def AnyTorchDictKeyType : AnyTypeOf<[ + Torch_AnyType, + Torch_IntType, + Torch_BoolType, + Torch_FloatType, + Torch_StringType, + Torch_FloatType, + AnyTorchTensorType, +], "Allowed dict key types">; + +// In alphabetic order. def AnyTorchType : AnyTypeOf<[ AnyTorchScalarType, AnyTorchTensorType, - Torch_TupleType, - Torch_StringType, + Torch_AnyType, + Torch_DictType, + Torch_DeviceType, + Torch_ListType, + Torch_LinearParamsType, + Torch_NumberType, Torch_NnModuleType, Torch_NoneType, Torch_OptionalType, - Torch_ListType, - Torch_DeviceType, - Torch_LinearParamsType, + Torch_StringType, + Torch_TupleType, ], "Any type that is legal to pass to a Torch kernel">; def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">; diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index d5c04751c..cb455f980 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -225,3 +225,39 @@ bool npcompTypeIsATorchString(MlirType t) { MlirType npcompTorchStringTypeGet(MlirContext context) { return wrap(Torch::StringType::get(unwrap(context))); } + +//===----------------------------------------------------------------------===// +// torch.any type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchAny(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchAnyTypeGet(MlirContext context) { + return wrap(Torch::AnyType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// torch.number type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchNumber(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchNumberTypeGet(MlirContext context) { + return wrap(Torch::NumberType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// torch.Dict type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchDict(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) { + return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType))); +} diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 5beac14aa..cd28a4a74 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -93,9 +93,31 @@ static LogicalResult verify(NnModuleOp op) { bool isValidSubtype(Type subtype, Type type) { if (subtype == type) return true; + + if (auto any = type.dyn_cast()) + return true; + + if (auto number = type.dyn_cast()) + return subtype.isa() || subtype.isa(); + if (auto optional = type.dyn_cast()) return isValidSubtype(subtype, optional.getContainedType()) || subtype.isa(); + + if (auto tuple = type.dyn_cast()) { + if (!subtype.isa()) + return false; + auto subtypes = subtype.cast().getContainedTypes(); + auto types = tuple.getContainedTypes(); + if (subtypes.size() != types.size()) + return false; + for (auto t : llvm::zip(subtypes, types)) { + if (!isValidSubtype(std::get<0>(t), std::get<1>(t))) + return false; + } + return true; + } + // TODO: This is not subtyping according to PEP 483. See description // of NonValueTensorType. if (subtype.isa() && type.isa() && @@ -142,7 +164,7 @@ static LogicalResult verify(PrimListConstructOp op) { auto resultType = op.getResult().getType(); auto resultElementType = resultType.dyn_cast().getContainedType(); auto matchResultElementType = [&](Type type) { - return type.getTypeID() == resultElementType.getTypeID(); + return isValidSubtype(type, resultElementType); }; if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) { return op.emitError() << "operand types should have the same type as the " @@ -152,6 +174,27 @@ static LogicalResult verify(PrimListConstructOp op) { return success(); } +//===----------------------------------------------------------------------===// +// PrimDictConstructOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(PrimDictConstructOp op) { + auto isValidSubTypeOf = [](Type expectedType) { + return [=](Type type) { return isValidSubtype(type, expectedType); }; + }; + + Type keyType = op.getKeyType(); + if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType))) { + return op.emitError() << "keys should be of Dict key type"; + } + + Type valueType = op.getValueType(); + if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType))) { + return op.emitError() << "values should be of Dict value type"; + } + return success(); +} + //===----------------------------------------------------------------------===// // ClassTypeOp //===----------------------------------------------------------------------===//