2021-02-02 09:59:42 +08:00
|
|
|
# -*- Python -*-
|
|
|
|
# This file is licensed under a pytorch-style license
|
|
|
|
# See frontends/pytorch/LICENSE for license information.
|
|
|
|
|
|
|
|
import typing
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch_mlir
|
|
|
|
|
|
|
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
|
|
|
|
|
|
|
mb = torch_mlir.ModuleBuilder()
|
|
|
|
|
|
|
|
class TestModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
self.t1 = torch.ones(1)
|
|
|
|
self.t2 = torch.ones(1)
|
|
|
|
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
# CHECK-LABEL: func private @__torch__.TestModule.forward(
|
2021-02-18 03:28:51 +08:00
|
|
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !basicpy.NoneType {
|
2021-02-02 09:59:42 +08:00
|
|
|
def forward(self):
|
|
|
|
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
|
|
|
|
# CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]]
|
|
|
|
self.t1 = self.t2
|
2021-02-06 06:54:04 +08:00
|
|
|
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
|
|
|
|
self.callee(self.t1, self.t2)
|
Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).
Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff
The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.
Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:
```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```
That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-02-27 08:20:35 +08:00
|
|
|
# CHECK-LABEL: func private @__torch__.TestModule.callee(
|
2021-02-18 03:28:51 +08:00
|
|
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
|
2021-02-06 06:54:04 +08:00
|
|
|
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
|
|
|
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
|
|
|
|
def callee(self, x, y):
|
2021-02-02 09:59:42 +08:00
|
|
|
pass
|
|
|
|
|
|
|
|
test_module = TestModule()
|
|
|
|
recursivescriptmodule = torch.jit.script(test_module)
|
|
|
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
|
|
|
mb.import_module(recursivescriptmodule._c)
|
|
|
|
mb.module.operation.print()
|