torch-mlir/frontends/pytorch/test/node_import/function-derefine.py

Properly model "derefinement".

In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type.

We introduce a new op `torch.derefine` that models this impedance mismatch. The op casts a value from its type to a supertype of that type, which is exactly the behavior described above.

Recommended review order:
- TorchOps.td for new torch.derefine (and updated docs for `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h. Only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions, which are now prefixed with `__torch__.`, but that is more correct anyway: it is their linkage name in the `torch::jit::CompilationUnit`. The previous `mb.import_function` was actually buggy when functions called each other, because it would reference their unqualified names.

With this change, we can import `resnet18` from `torchvision` :)

IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-02 09:24:15 +08:00
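As a rough illustration of the rule above, here is a minimal, hypothetical Python sketch of where an importer would need to insert `torch.derefine`. Everything in it (`IntT`, `NoneT`, `OptionalT`, `is_subtype`, `Emitter`) is invented for this example and only models `int`, `None`, and `Optional`. The actual insertion logic is the C++ in node_importer.cpp and function_importer.cpp, and the printed IR only imitates the syntax used in the CHECK lines of this test.

```python
# Hypothetical sketch (not torch-mlir code): models, in plain Python, where an
# importer must insert `torch.derefine` because MLIR needs exact types while
# TorchScript lets a subtype flow into any use that expects a supertype.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class IntT:
    """TorchScript `int`, printed as MLIR `i64`."""

    def __str__(self) -> str:
        return "i64"


@dataclass(frozen=True)
class NoneT:
    """TorchScript `None`, printed as `!torch.none`."""

    def __str__(self) -> str:
        return "!torch.none"


@dataclass(frozen=True)
class OptionalT:
    """TorchScript `Optional[T]`, printed as `!torch.optional<T>`."""

    contained: object

    def __str__(self) -> str:
        return f"!torch.optional<{self.contained}>"


def is_subtype(sub: object, sup: object) -> bool:
    """True if a value of type `sub` may be used where `sup` is required."""
    if sub == sup:
        return True
    if isinstance(sup, OptionalT):
        # Both T and None are subtypes of Optional[T]. The reverse direction
        # (Optional[T] to T) is a refinement and needs torch.prim.unchecked_cast.
        return isinstance(sub, NoneT) or is_subtype(sub, sup.contained)
    return False


class Emitter:
    """Collects textual IR lines and hands out fresh SSA names."""

    def __init__(self) -> None:
        self.lines: List[str] = []
        self._next_id = 0

    def fresh(self) -> str:
        name = f"%{self._next_id}"
        self._next_id += 1
        return name

    def coerce(self, value: str, value_type: object, required_type: object) -> str:
        """Return an SSA name of exactly `required_type`, derefining if needed."""
        if value_type == required_type:
            return value  # Types already match: no op needed.
        if not is_subtype(value_type, required_type):
            raise TypeError(f"cannot derefine {value_type} to {required_type}")
        result = self.fresh()
        self.lines.append(
            f"{result} = torch.derefine {value} : {value_type} to {required_type}")
        return result

    def call(self, callee: str, param_types: List[object], result_type: object,
             args: List[Tuple[str, object]]) -> str:
        """Call `callee`, derefining each argument to its declared parameter type."""
        sig = f"({', '.join(map(str, param_types))}) -> {result_type}"
        fn = self.fresh()
        self.lines.append(f"{fn} = constant @{callee} : {sig}")
        operands = [self.coerce(v, t, p) for (v, t), p in zip(args, param_types)]
        result = self.fresh()
        self.lines.append(
            f"{result} = call_indirect {fn}({', '.join(operands)}) : {sig}")
        return result


if __name__ == "__main__":
    # Mirrors calls_optional_arg: an `int` passed to an Optional[int] parameter.
    e = Emitter()
    e.call("__torch__.optional_arg", [OptionalT(IntT())], NoneT(),
           [("%arg0", IntT())])
    print("\n".join(e.lines))
    # %0 = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !torch.none
    # %1 = torch.derefine %arg0 : i64 to !torch.optional<i64>
    # %2 = call_indirect %0(%1) : (!torch.optional<i64>) -> !torch.none
```

The design point the sketch tries to capture is that `call` coerces each argument to the callee's *declared* parameter types rather than to the argument types. That is why the boundary types have to come from the function schema.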
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import torch
import torch_mlir

import typing

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

# CHECK-LABEL: func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : i64 to !torch.optional<i64>
# CHECK: return %[[RET]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def optional_return(i: int) -> typing.Optional[int]:
  return i

# CHECK-LABEL: func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> !torch.none {
@mb.import_function
@torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None:
  return

# CHECK-LABEL: func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.none {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !torch.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : i64 to !torch.optional<i64>
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<i64>) -> !torch.none
@mb.import_function
@torch.jit.script
def calls_optional_arg(i: int):
  optional_arg(i)

mb.module.operation.print()
print()