mirror of https://github.com/llvm/torch-mlir
[prim] Add support for List|TupleUnpack
parent
df4c5764da
commit
ca3a02da28
|
@ -215,6 +215,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::TupleUnpack) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.TupleUnpack", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::ListUnpack) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.ListUnpack", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unhandled.
|
||||
{
|
||||
std::stringstream msg;
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
|
@ -62,5 +64,25 @@ def prim_unchecked_cast(i: typing.Optional[int]):
|
|||
return 3
|
||||
return i
|
||||
|
||||
# CHECK-LABEL: func @prim_TupleUnpack(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
|
||||
# CHECK: return %[[RET]]#0 : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_TupleUnpack(lt: typing.Tuple[int, int]):
|
||||
val, _ = lt
|
||||
return val
|
||||
|
||||
# CHECK-LABEL: func @prim_ListUnpack(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
|
||||
# CHECK: return %[[RET]]#1 : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_ListUnpack(lt: typing.List[int]):
|
||||
_, val, _ = lt
|
||||
return val
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -469,4 +469,26 @@ def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", []> {
|
||||
let summary = "TorchScript prim::ListUnpack op";
|
||||
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimTupleUnpackOp: Torch_Op<"prim.TupleUnpack", []> {
|
||||
let summary = "TorchScript prim::TupleUnpack op";
|
||||
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($results)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
Loading…
Reference in New Issue