Add support for `prim::unchecked_cast`.

This arises when casting optionals, which happens a lot especially
around handling of default arguments (python `if arg is None` idiom).

In this case, the offending code for the model is in max_pool2d:
[code link](b3bf08e67f/torch/nn/functional.py (L657))
pull/176/head
Sean Silva 2021-03-01 15:26:57 -08:00
parent 939d36906f
commit df4c5764da
3 changed files with 43 additions and 0 deletions

View File

@ -206,6 +206,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::unchecked_cast) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.unchecked_cast", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
// Unhandled.
{
std::stringstream msg;

View File

@ -5,6 +5,8 @@
import torch
import torch_mlir
import typing
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
@ -40,5 +42,25 @@ def prim_Print(x):
def prim_RaiseException():
raise Exception("Error")
# CHECK-LABEL: func @prim_unchecked_cast(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[C3:.*]] = constant 3 : i64
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
# CHECK: scf.yield %[[C3]] : i64
# CHECK: } else {
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64
# CHECK: scf.yield %[[CASTED]] : i64
# CHECK: }
# CHECK: return %[[RESULT:.*]] : i64
@mb.import_function
@torch.jit.script
def prim_unchecked_cast(i: typing.Optional[int]):
if i is None:
return 3
return i
mb.module.operation.print()
print()

View File

@ -457,4 +457,16 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
}];
}
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
let summary = "TorchScript prim::unchecked_cast op";
// TODO: This seems to mostly be used for casting "optional" to the contained
// type. Verify that and tighten the verifier.
let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
#endif // TORCH_OPS