# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import torch
import torch_mlir

import typing

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
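
# Each function below is compiled with @torch.jit.script and imported into the
# MLIR module held by this ModuleBuilder via @mb.import_function; the module is
# printed at the end of the file so the RUN line can verify it with FileCheck.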
mb = torch_mlir.ModuleBuilder()
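
# In TorchScript, calling `_to_tensor` on an int produces a prim::NumToTensor
# node, which is imported here as torch.prim.NumToTensor yielding an unranked
# ndarray of unknown dtype.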
# CHECK-LABEL: func @prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: }
@mb.import_function
@torch.jit.script
def prim_NumToTensor(i: int):
    return _to_tensor(i)
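
# A TorchScript `print(...)` call lowers to prim::Print; the string argument
# "x" is imported as a basicpy.bytes_constant and passed to torch.prim.Print
# along with the tensor argument.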
# CHECK-LABEL: func @prim_Print(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
@mb.import_function
@torch.jit.script
def prim_Print(x):
    print("x", x)
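
# `raise` in TorchScript becomes prim::RaiseException. Since the function never
# returns normally, the importer materializes a torch.prim.Uninitialized value
# to serve as the !basicpy.NoneType return value.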
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
# CHECK: return %[[NONE]] : !basicpy.NoneType
@mb.import_function
@torch.jit.script
def prim_RaiseException():
    raise Exception("Error")
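
# The `if i is None` refinement imports as an aten::__is__ kernel call feeding
# an scf.if; in the else branch, torch.prim.unchecked_cast extracts the i64
# payload from the !torch.optional<i64> argument.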
# CHECK-LABEL: func @prim_unchecked_cast(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[C3:.*]] = constant 3 : i64
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
# CHECK: scf.yield %[[C3]] : i64
# CHECK: } else {
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64
# CHECK: scf.yield %[[CASTED]] : i64
# CHECK: }
# CHECK: return %[[RESULT]] : i64
@mb.import_function
@torch.jit.script
def prim_unchecked_cast(i: typing.Optional[int]):
    if i is None:
        return 3
    return i
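
# Print the imported module (followed by a trailing newline) so the RUN line
# above can pipe the output through npcomp-opt and FileCheck.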
mb.module.operation.print()
print()