mirror of https://github.com/llvm/torch-mlir
Import Machine Translation model to MLIR.
This includes the following changes to import MT model into MLIR. There are still a lot of work to for actual compilation. - Add `torch.dict<>`, `torch.any`, `torch.number` types - Add `torch.prim.DictConstruct` op - Fix `torch.prim.TupleConstruct` op assembly format to include resulting typespull/276/head
parent
a3bfd115ee
commit
bfc3ee35c6
|
@ -6,8 +6,7 @@ src_dir="$(realpath $(dirname $0)/..)"
|
|||
build_dir="$(realpath "${NPCOMP_BUILD_DIR:-$src_dir/build}")"
|
||||
torch_ir_dir="${src_dir}/include/npcomp/Dialect/Torch/IR"
|
||||
|
||||
export PYTHONPATH="${build_dir}/python"
|
||||
|
||||
source $src_dir/.env
|
||||
#ninja -C "${build_dir}"
|
||||
python -m torch_mlir_utils.codegen.torch_ods_gen \
|
||||
--torch_ir_dir="${torch_ir_dir}" \
|
||||
|
|
|
@ -48,7 +48,7 @@ using namespace torch_mlir;
|
|||
namespace {
|
||||
struct IValueHasher {
|
||||
size_t operator()(const c10::IValue &ivalue) const {
|
||||
if (ivalue.isObject() || ivalue.isList()) {
|
||||
if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) {
|
||||
return std::hash<const void *>()(
|
||||
static_cast<const void *>(ivalue.internalToPointer()));
|
||||
}
|
||||
|
@ -278,6 +278,22 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isGenericDict()) {
|
||||
c10::Dict<c10::IValue, c10::IValue> dict = ivalue.toGenericDict();
|
||||
std::vector<MlirValue> keys;
|
||||
std::vector<MlirValue> values;
|
||||
for (auto it = dict.begin(); it != dict.end(); it++) {
|
||||
keys.push_back(importIValue(it->key()));
|
||||
values.push_back(importIValue(it->value()));
|
||||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.DictConstruct", loc,
|
||||
npcompTorchDictTypeGet(
|
||||
typeMapper.mapFromTorchType(loc, dict.keyType()),
|
||||
typeMapper.mapFromTorchType(loc, dict.valueType())),
|
||||
keys, values);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTuple()) {
|
||||
auto list = ivalue.toTuple()->elements();
|
||||
std::vector<MlirValue> operands;
|
||||
|
|
|
@ -44,18 +44,45 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
using InputsTransformFn =
|
||||
std::function<std::vector<MlirValue>(std::vector<MlirValue> &)>;
|
||||
|
||||
// The inputs of `DictConstruct` in TorchScript IR are in the order
|
||||
// like k0, v0, k1, v1. Rearrange them to put the key operands together and
|
||||
// then the value operands like k0, k1,v0, v1. This is the expected format by
|
||||
// the corresponding MLIR op.
|
||||
static std::vector<MlirValue>
|
||||
rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
|
||||
if (inputs.empty())
|
||||
return inputs;
|
||||
assert(inputs.size() % 2 == 0 &&
|
||||
"DictConstruct must have even number of operands");
|
||||
|
||||
std::vector<MlirValue> rearranged;
|
||||
std::vector<MlirValue> values;
|
||||
for (auto it = inputs.begin(); it != inputs.end(); it++) {
|
||||
rearranged.push_back(*it);
|
||||
values.push_back(*++it);
|
||||
}
|
||||
rearranged.insert(rearranged.end(), values.begin(), values.end());
|
||||
return rearranged;
|
||||
}
|
||||
|
||||
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||
TypeMapper typeMapper(context);
|
||||
MlirLocation loc = getMlirLocationFromNode(context, node);
|
||||
auto kind = node->kind();
|
||||
|
||||
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName) {
|
||||
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
|
||||
InputsTransformFn t) {
|
||||
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, opName, loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
t ? t(mappedInputs) : mappedInputs);
|
||||
mapResults(node, operation);
|
||||
};
|
||||
|
||||
auto createAndMapNodeWithAttribute = [&](Node *node,
|
||||
const std::string &opName,
|
||||
const std::string &attrName,
|
||||
|
@ -80,12 +107,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
// Builtin interpreter ops with no operator/schema.
|
||||
InputsTransformFn transformer =
|
||||
kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
|
||||
switch (kind) {
|
||||
case c10::prim::ListUnpack:
|
||||
case c10::prim::ListConstruct:
|
||||
case c10::prim::TupleConstruct: {
|
||||
createAndMapTrivialNode(node,
|
||||
"torch.prim." + std::string(kind.toUnqualString()));
|
||||
case c10::prim::TupleConstruct:
|
||||
case c10::prim::DictConstruct: {
|
||||
createAndMapTrivialNode(
|
||||
node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
|
||||
return;
|
||||
}
|
||||
case c10::prim::GetAttr:
|
||||
|
|
|
@ -155,35 +155,24 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||
MlirType customClassType = mapCustomClassType(context, loc, classType);
|
||||
if (!mlirTypeIsNull(customClassType)) {
|
||||
return customClassType;
|
||||
}
|
||||
auto maybeName = classType->name();
|
||||
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
||||
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
case TypeKind::IntType: {
|
||||
return npcompTorchIntTypeGet(context);
|
||||
}
|
||||
case TypeKind::FloatType: {
|
||||
return npcompTorchFloatTypeGet(context);
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return npcompTorchOptionalTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::IntType: {
|
||||
return npcompTorchIntTypeGet(context);
|
||||
}
|
||||
case TypeKind::NoneType: {
|
||||
return npcompTorchNoneTypeGet(context);
|
||||
}
|
||||
case TypeKind::BoolType: {
|
||||
return npcompTorchBoolTypeGet(context);
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
return npcompTorchListTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
case TypeKind::NumberType: {
|
||||
return npcompTorchNumberTypeGet(context);
|
||||
}
|
||||
case TypeKind::StringType: {
|
||||
return npcompTorchStringTypeGet(context);
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return npcompTorchOptionalTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
std::vector<MlirType> containedTypes;
|
||||
|
@ -194,8 +183,32 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
return npcompTorchTupleTypeGet(context, containedTypes.size(),
|
||||
containedTypes.data());
|
||||
}
|
||||
case TypeKind::StringType: {
|
||||
return npcompTorchStringTypeGet(context);
|
||||
case TypeKind::ListType: {
|
||||
return npcompTorchListTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::DictType: {
|
||||
auto dictType = torchType->cast<c10::DictType>();
|
||||
return npcompTorchDictTypeGet(
|
||||
mapFromTorchType(loc, dictType->getKeyType()),
|
||||
mapFromTorchType(loc, dictType->getValueType()));
|
||||
}
|
||||
case TypeKind::NoneType: {
|
||||
return npcompTorchNoneTypeGet(context);
|
||||
}
|
||||
case TypeKind::AnyType: {
|
||||
auto anyType = torchType->cast<c10::AnyType>();
|
||||
return npcompTorchAnyTypeGet(context);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||
MlirType customClassType = mapCustomClassType(context, loc, classType);
|
||||
if (!mlirTypeIsNull(customClassType)) {
|
||||
return customClassType;
|
||||
}
|
||||
auto maybeName = classType->name();
|
||||
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
||||
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
}
|
||||
case TypeKind::DeviceObjType: {
|
||||
return npcompTorchDeviceTypeGet(context);
|
||||
|
@ -235,7 +248,7 @@ torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
|||
if (mlirTypeIsNull(type)) {
|
||||
std::stringstream msg;
|
||||
msg << "unsupported type in function schema: '"
|
||||
<< c10::toString(torchType) << "'";
|
||||
<< c10::toString(torchType) << "'";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return type;
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.d = {"key1": torch.tensor(1)}
|
||||
|
||||
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# CHECK: torch.attr "training" : !torch.bool
|
||||
# CHECK: torch.attr "_is_full_backward_hook" : !torch.optional<!torch.bool>
|
||||
# CHECK: torch.attr "d" : !torch.dict<!torch.str, !torch.tensor>
|
||||
# CHECK: }
|
||||
# CHECK: %[[K:.*]] = torch.constant.str "key1"
|
||||
# CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1> : tensor<si64>) : !torch.tensor<[],si64>
|
||||
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
|
||||
# CHECK-SAME keys(%[[K]] : !torch.str) values(%[[TENSOR]] : !torch.tensor<[],si64>)
|
||||
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "d", %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
test_module = TestModule()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c)
|
||||
mb.module.operation.print()
|
|
@ -0,0 +1,43 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
import collections
|
||||
from typing import Tuple, Optional, List, NamedTuple, Dict
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: builtin.func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
|
||||
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<!torch.str, !torch.tensor>
|
||||
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def dict_literal_empty() -> Dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
|
||||
# CHECK-LABEL: builtin.func @__torch__.dict_literal(
|
||||
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
|
||||
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.optional<!torch.tensor>> {
|
||||
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
|
||||
# CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
|
||||
# CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor) ->
|
||||
# CHECK-SAME: !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
|
||||
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
|
||||
# CHECK: }
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def dict_literal(k0: str, v0, k1: str,
|
||||
v1) -> Dict[str, Optional[torch.Tensor]]:
|
||||
my_dict: Dict[str, Optional[torch.Tensor]] = {k0: v0, k1: v1}
|
||||
return my_dict
|
||||
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
|
@ -2,23 +2,31 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import typing
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
|
||||
class Color(enum.Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
|
||||
# RUN: %PYTHON %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# To test errors, use a type that we don't support yet.
|
||||
try:
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def import_class(x: typing.Any):
|
||||
def import_class(x: Color):
|
||||
return x
|
||||
except Exception as e:
|
||||
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
||||
assert str(e) == "unsupported type in function schema: 'Any'"
|
||||
assert str(
|
||||
e) == "unsupported type in function schema: 'Enum<__torch__.Color>'"
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
|
|
@ -4,22 +4,65 @@
|
|||
|
||||
import torch
|
||||
import torch_mlir
|
||||
import collections
|
||||
from typing import Tuple, Optional, NamedTuple
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
|
||||
('f2', Optional[torch.Tensor])])
|
||||
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-LABEL: builtin.func @__torch__.tuple(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
||||
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
|
||||
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>
|
||||
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def f(t0, t1):
|
||||
def tuple(t0, t1):
|
||||
return t0, t1
|
||||
|
||||
assert isinstance(f, torch.jit.ScriptFunction)
|
||||
|
||||
# CHECK-LABEL: builtin.func @__torch__.tuple_optional(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
||||
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
|
||||
# CHECK: %[[TNEW:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
|
||||
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
|
||||
# CHECK: %[[RET:.*]] = torch.derefine %[[TNEW]] :
|
||||
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> to
|
||||
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
|
||||
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def tuple_optional(
|
||||
t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return t0, t1
|
||||
|
||||
|
||||
# CHECK-LABEL: builtin.func @__torch__.namedtuple_optional(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
||||
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
|
||||
# CHECK-SAME: !torch.tensor, !torch.tensor ->
|
||||
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
|
||||
# CHECK: }
|
||||
#
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def namedtuple_optional(
|
||||
t0, t1) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return NT(t0, t1)
|
||||
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -190,6 +190,37 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t);
|
|||
/// Gets the !torch.str type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.any type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.any type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t);
|
||||
|
||||
/// Gets the !torch.str type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.number type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.number type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t);
|
||||
|
||||
/// Gets the !torch.number type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.dict type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.dict type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t);
|
||||
|
||||
/// Gets the !torch.dict type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType,
|
||||
MlirType valueType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -326,7 +326,8 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
|||
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||
NoSideEffect,
|
||||
TypesMatchWith<"contained types correspond to operand types",
|
||||
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))">
|
||||
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
|
||||
"isValidSubtype">
|
||||
]> {
|
||||
let summary = "TorchScript prim::TupleConstruct op";
|
||||
let description = [{
|
||||
|
@ -342,7 +343,7 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
|||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$elements attr-dict `:` type($elements)
|
||||
$elements attr-dict `:` type($elements) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -366,6 +367,33 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
||||
AllowsTypeRefinement,
|
||||
SameVariadicOperandSize,
|
||||
]> {
|
||||
let summary = "TorchScript prim::DictConstruct op";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyTorchDictKeyType>:$keys,
|
||||
Variadic<AnyTorchType>:$values
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Torch_DictType:$result
|
||||
);
|
||||
|
||||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = [{
|
||||
`keys` `(` ($keys^ `:` type($keys))? `)` `values` `(` ($values^ `:` type($values))? `)` attr-dict `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
|
||||
Type getValueType() { return getType().cast<DictType>().getValueType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
|
||||
let summary = "TorchScript prim::GetAttr op";
|
||||
|
||||
|
|
|
@ -321,6 +321,56 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AnyType : Torch_Type<"Any", "any"> {
|
||||
let summary = "Torch any type";
|
||||
let description = [{
|
||||
Represent any torch type. All the other types are sub types of Any type.
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_NumberType : Torch_Type<"Number", "number"> {
|
||||
let summary = "Torch number type";
|
||||
let description = [{
|
||||
The Int, Float and Complex type are sub types of Number type.
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_DictType : Torch_Type<"Dict", "dict"> {
|
||||
|
||||
let summary = "!torch.dict[KT, VT] ";
|
||||
let parameters = (ins "::mlir::Type":$keyType, "::mlir::Type":$valueType);
|
||||
let description = [{
|
||||
Torch Dict type with key and value type.
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
$_printer << getMnemonic() << "<" << getImpl()->keyType << ", " << getImpl()->valueType << ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
Type keyType;
|
||||
if ($_parser.parseType(keyType))
|
||||
return Type();
|
||||
if ($_parser.parseComma())
|
||||
return Type();
|
||||
Type valueType;
|
||||
if ($_parser.parseType(valueType))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, keyType, valueType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
|
||||
"::mlir::Type":$valueType), [{
|
||||
return Base::get(keyType.getContext(), keyType, valueType);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -348,17 +398,33 @@ def AnyTorchScalarType : AnyTypeOf<[
|
|||
Torch_BoolType,
|
||||
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
|
||||
|
||||
// See function `DictTypePtr create(TypePtr key, TypePtr value)`
|
||||
// in aten/src/ATen/core/jit_type.h.
|
||||
def AnyTorchDictKeyType : AnyTypeOf<[
|
||||
Torch_AnyType,
|
||||
Torch_IntType,
|
||||
Torch_BoolType,
|
||||
Torch_FloatType,
|
||||
Torch_StringType,
|
||||
Torch_FloatType,
|
||||
AnyTorchTensorType,
|
||||
], "Allowed dict key types">;
|
||||
|
||||
// In alphabetic order.
|
||||
def AnyTorchType : AnyTypeOf<[
|
||||
AnyTorchScalarType,
|
||||
AnyTorchTensorType,
|
||||
Torch_TupleType,
|
||||
Torch_StringType,
|
||||
Torch_AnyType,
|
||||
Torch_DictType,
|
||||
Torch_DeviceType,
|
||||
Torch_ListType,
|
||||
Torch_LinearParamsType,
|
||||
Torch_NumberType,
|
||||
Torch_NnModuleType,
|
||||
Torch_NoneType,
|
||||
Torch_OptionalType,
|
||||
Torch_ListType,
|
||||
Torch_DeviceType,
|
||||
Torch_LinearParamsType,
|
||||
Torch_StringType,
|
||||
Torch_TupleType,
|
||||
], "Any type that is legal to pass to a Torch kernel">;
|
||||
|
||||
def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">;
|
||||
|
|
|
@ -225,3 +225,39 @@ bool npcompTypeIsATorchString(MlirType t) {
|
|||
MlirType npcompTorchStringTypeGet(MlirContext context) {
|
||||
return wrap(Torch::StringType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.any type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchAny(MlirType t) {
|
||||
return unwrap(t).isa<Torch::AnyType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchAnyTypeGet(MlirContext context) {
|
||||
return wrap(Torch::AnyType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.number type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNumber(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NumberType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNumberTypeGet(MlirContext context) {
|
||||
return wrap(Torch::NumberType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.Dict type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchDict(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DictType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
||||
}
|
||||
|
|
|
@ -93,9 +93,31 @@ static LogicalResult verify(NnModuleOp op) {
|
|||
bool isValidSubtype(Type subtype, Type type) {
|
||||
if (subtype == type)
|
||||
return true;
|
||||
|
||||
if (auto any = type.dyn_cast<AnyType>())
|
||||
return true;
|
||||
|
||||
if (auto number = type.dyn_cast<NumberType>())
|
||||
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
|
||||
|
||||
if (auto optional = type.dyn_cast<OptionalType>())
|
||||
return isValidSubtype(subtype, optional.getContainedType()) ||
|
||||
subtype.isa<Torch::NoneType>();
|
||||
|
||||
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
|
||||
if (!subtype.isa<Torch::TupleType>())
|
||||
return false;
|
||||
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
|
||||
auto types = tuple.getContainedTypes();
|
||||
if (subtypes.size() != types.size())
|
||||
return false;
|
||||
for (auto t : llvm::zip(subtypes, types)) {
|
||||
if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
|
||||
|
@ -142,7 +164,7 @@ static LogicalResult verify(PrimListConstructOp op) {
|
|||
auto resultType = op.getResult().getType();
|
||||
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
|
||||
auto matchResultElementType = [&](Type type) {
|
||||
return type.getTypeID() == resultElementType.getTypeID();
|
||||
return isValidSubtype(type, resultElementType);
|
||||
};
|
||||
if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) {
|
||||
return op.emitError() << "operand types should have the same type as the "
|
||||
|
@ -152,6 +174,27 @@ static LogicalResult verify(PrimListConstructOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimDictConstructOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(PrimDictConstructOp op) {
|
||||
auto isValidSubTypeOf = [](Type expectedType) {
|
||||
return [=](Type type) { return isValidSubtype(type, expectedType); };
|
||||
};
|
||||
|
||||
Type keyType = op.getKeyType();
|
||||
if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType))) {
|
||||
return op.emitError() << "keys should be of Dict key type";
|
||||
}
|
||||
|
||||
Type valueType = op.getValueType();
|
||||
if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType))) {
|
||||
return op.emitError() << "values should be of Dict value type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ClassTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue