diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp index 04fbee12c..5991a7705 100644 --- a/frontends/pytorch/csrc/builder/node_importer.cpp +++ b/frontends/pytorch/csrc/builder/node_importer.cpp @@ -273,6 +273,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) { return; } + if (kind == c10::prim::dtype) { + MlirOperation operation = + createMlirOperationAtEnd(appendToBlock, "torch.prim.dtype", loc, + getMlirTypesFromValues(loc, node->outputs()), + lookupMappedValues(node->inputs())); + mapResults(node, operation); + return; + } + // Unhandled. { std::stringstream msg; diff --git a/frontends/pytorch/test/node_import/prim.py b/frontends/pytorch/test/node_import/prim.py index 4b8889cbf..25c6503ad 100644 --- a/frontends/pytorch/test/node_import/prim.py +++ b/frontends/pytorch/test/node_import/prim.py @@ -93,5 +93,14 @@ def prim_ListUnpack(l: typing.List[int]): _, val, _ = l return val +# CHECK-LABEL: func @__torch__.prim_dtype( +# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 { +# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64 +# CHECK: return %[[RET]] : i64 +@mb.import_function +@torch.jit.script +def prim_dtype(x): + return x.dtype + mb.module.operation.print() print() diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index 812233e8e..25efc38cb 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -552,4 +552,13 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", []> { }]; } +def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> { + let summary = "TorchScript prim::dtype op"; + let arguments = (ins AnyTorchTensorType:$tensor); + let results = (outs AnyTorchNumberType:$result); + let assemblyFormat = [{ + $tensor attr-dict `:` type($tensor) `->` type($result) + }]; +} + #endif // TORCH_OPS