[prim] Add support for List|TupleUnpack

pull/177/head
Bryce Arden 2021-03-02 16:39:48 -06:00 committed by Sean Silva
parent df4c5764da
commit ca3a02da28
3 changed files with 62 additions and 0 deletions

View File

@ -215,6 +215,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::TupleUnpack) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.TupleUnpack", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::ListUnpack) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.ListUnpack", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
// Unhandled.
{
std::stringstream msg;

View File

@ -2,6 +2,8 @@
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
@ -62,5 +64,25 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return 3
return i
# CHECK-LABEL: func @prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
# CHECK: return %[[RET]]#0 : i64
@mb.import_function
@torch.jit.script
def prim_TupleUnpack(lt: typing.Tuple[int, int]):
val, _ = lt
return val
# CHECK-LABEL: func @prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
# CHECK: return %[[RET]]#1 : i64
@mb.import_function
@torch.jit.script
def prim_ListUnpack(lt: typing.List[int]):
_, val, _ = lt
return val
mb.module.operation.print()
print()

View File

@ -469,4 +469,26 @@ def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
}];
}
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", []> {
let summary = "TorchScript prim::ListUnpack op";
let arguments = (ins AnyTorchType:$operand);
let results = (outs Variadic<AnyTorchType>:$results);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($results)
}];
}
def Torch_PrimTupleUnpackOp: Torch_Op<"prim.TupleUnpack", []> {
let summary = "TorchScript prim::TupleUnpack op";
let arguments = (ins AnyTorchType:$operand);
let results = (outs Variadic<AnyTorchType>:$results);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($results)
}];
}
#endif // TORCH_OPS