mirror of https://github.com/llvm/torch-mlir
Add support for `prim::unchecked_cast`.
This arises when casting optionals, which happens a lot especially
around handling of default arguments (python `if arg is None` idiom).
In this case, the offending code for the model is in max_pool2d:
[code link](b3bf08e67f/torch/nn/functional.py (L657)
)
pull/176/head
parent
939d36906f
commit
df4c5764da
|
@ -206,6 +206,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::unchecked_cast) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.unchecked_cast", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unhandled.
|
||||
{
|
||||
std::stringstream msg;
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import typing
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
@ -40,5 +42,25 @@ def prim_Print(x):
|
|||
def prim_RaiseException():
|
||||
raise Exception("Error")
|
||||
|
||||
# CHECK-LABEL: func @prim_unchecked_cast(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
|
||||
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
|
||||
# CHECK: scf.yield %[[C3]] : i64
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64
|
||||
# CHECK: scf.yield %[[CASTED]] : i64
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RESULT:.*]] : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_unchecked_cast(i: typing.Optional[int]):
|
||||
if i is None:
|
||||
return 3
|
||||
return i
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -457,4 +457,16 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
|
||||
let summary = "TorchScript prim::unchecked_cast op";
|
||||
// TODO: This seems to mostly be used for casting "optional" to the contained
|
||||
// type. Verify that and tighten the verifier.
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
Loading…
Reference in New Issue