prim::dtype op

pull/184/head
Sean Silva 2021-03-10 16:41:18 -08:00
parent 5fed296904
commit 01b8a01e1b
3 changed files with 27 additions and 0 deletions

View File

@ -273,6 +273,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::dtype) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.dtype", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
// Unhandled.
{
std::stringstream msg;

View File

@ -93,5 +93,14 @@ def prim_ListUnpack(l: typing.List[int]):
_, val, _ = l
return val
# CHECK-LABEL: func @__torch__.prim_dtype(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
# CHECK: return %[[RET]] : i64
@mb.import_function
@torch.jit.script
def prim_dtype(x):
return x.dtype
mb.module.operation.print()
print()

View File

@ -552,4 +552,13 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", []> {
}];
}
def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
let summary = "TorchScript prim::dtype op";
let arguments = (ins AnyTorchTensorType:$tensor);
let results = (outs AnyTorchNumberType:$result);
let assemblyFormat = [{
$tensor attr-dict `:` type($tensor) `->` type($result)
}];
}
#endif // TORCH_OPS