mirror of https://github.com/llvm/torch-mlir
prim::dtype op
parent
5fed296904
commit
01b8a01e1b
|
@ -273,6 +273,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (kind == c10::prim::dtype) {
|
||||||
|
MlirOperation operation =
|
||||||
|
createMlirOperationAtEnd(appendToBlock, "torch.prim.dtype", loc,
|
||||||
|
getMlirTypesFromValues(loc, node->outputs()),
|
||||||
|
lookupMappedValues(node->inputs()));
|
||||||
|
mapResults(node, operation);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Unhandled.
|
// Unhandled.
|
||||||
{
|
{
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
|
|
|
@ -93,5 +93,14 @@ def prim_ListUnpack(l: typing.List[int]):
|
||||||
_, val, _ = l
|
_, val, _ = l
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.prim_dtype(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
|
||||||
|
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
|
||||||
|
# CHECK: return %[[RET]] : i64
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_dtype(x):
|
||||||
|
return x.dtype
|
||||||
|
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
|
@ -552,4 +552,13 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
|
||||||
|
let summary = "TorchScript prim::dtype op";
|
||||||
|
let arguments = (ins AnyTorchTensorType:$tensor);
|
||||||
|
let results = (outs AnyTorchNumberType:$result);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$tensor attr-dict `:` type($tensor) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCH_OPS
|
#endif // TORCH_OPS
|
||||||
|
|
Loading…
Reference in New Issue