mirror of https://github.com/llvm/torch-mlir
prim::dtype op
parent
5fed296904
commit
01b8a01e1b
|
@ -273,6 +273,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::dtype) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.dtype", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unhandled.
|
||||
{
|
||||
std::stringstream msg;
|
||||
|
|
|
@ -93,5 +93,14 @@ def prim_ListUnpack(l: typing.List[int]):
|
|||
_, val, _ = l
|
||||
return val
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_dtype(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
|
||||
# CHECK: return %[[RET]] : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_dtype(x):
|
||||
return x.dtype
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -552,4 +552,13 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
|
||||
let summary = "TorchScript prim::dtype op";
|
||||
let arguments = (ins AnyTorchTensorType:$tensor);
|
||||
let results = (outs AnyTorchNumberType:$result);
|
||||
let assemblyFormat = [{
|
||||
$tensor attr-dict `:` type($tensor) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
Loading…
Reference in New Issue