From ca3a02da28a805a7046b6fc3e63f630d1eedf547 Mon Sep 17 00:00:00 2001 From: Bryce Arden Date: Tue, 2 Mar 2021 16:39:48 -0600 Subject: [PATCH] [prim] Add support for List|TupleUnpack --- .../pytorch/csrc/builder/node_importer.cpp | 18 +++++++++++++++ frontends/pytorch/test/node_import/prim.py | 22 +++++++++++++++++++ include/npcomp/Dialect/Torch/IR/TorchOps.td | 22 +++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp index de5771baa..920550382 100644 --- a/frontends/pytorch/csrc/builder/node_importer.cpp +++ b/frontends/pytorch/csrc/builder/node_importer.cpp @@ -215,6 +215,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) { return; } + if (kind == c10::prim::TupleUnpack) { + MlirOperation operation = + createMlirOperationAtEnd(appendToBlock, "torch.prim.TupleUnpack", loc, + getMlirTypesFromValues(loc, node->outputs()), + lookupMappedValues(node->inputs())); + mapResults(node, operation); + return; + } + + if (kind == c10::prim::ListUnpack) { + MlirOperation operation = + createMlirOperationAtEnd(appendToBlock, "torch.prim.ListUnpack", loc, + getMlirTypesFromValues(loc, node->outputs()), + lookupMappedValues(node->inputs())); + mapResults(node, operation); + return; + } + // Unhandled. { std::stringstream msg; diff --git a/frontends/pytorch/test/node_import/prim.py b/frontends/pytorch/test/node_import/prim.py index a4bef107d..845d388bf 100644 --- a/frontends/pytorch/test/node_import/prim.py +++ b/frontends/pytorch/test/node_import/prim.py @@ -2,6 +2,8 @@ # This file is licensed under a pytorch-style license # See frontends/pytorch/LICENSE for license information. +import typing + import torch import torch_mlir @@ -62,5 +64,25 @@ def prim_unchecked_cast(i: typing.Optional[int]): return 3 return i +# CHECK-LABEL: func @prim_TupleUnpack( +# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 { +# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64 +# CHECK: return %[[RET]]#0 : i64 +@mb.import_function +@torch.jit.script +def prim_TupleUnpack(lt: typing.Tuple[int, int]): + val, _ = lt + return val + +# CHECK-LABEL: func @prim_ListUnpack( +# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 { +# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64 +# CHECK: return %[[RET]]#1 : i64 +@mb.import_function +@torch.jit.script +def prim_ListUnpack(lt: typing.List[int]): + _, val, _ = lt + return val + mb.module.operation.print() print() diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index 7627340de..958be63d3 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -469,4 +469,26 @@ def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> { }]; } +def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", []> { + let summary = "TorchScript prim::ListUnpack op"; + + let arguments = (ins AnyTorchType:$operand); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($results) + }]; +} + +def Torch_PrimTupleUnpackOp: Torch_Op<"prim.TupleUnpack", []> { + let summary = "TorchScript prim::TupleUnpack op"; + + let arguments = (ins AnyTorchType:$operand); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($results) + }]; +} + #endif // TORCH_OPS