Commit Graph

72 Commits (43dba03afdd82bbd0ac7d1d06596cf13f6daed6b)

Author SHA1 Message Date
Sean Silva 43dba03afd Properly model "derefinement".
In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.

We introduce a new op `torch.derefine` that models that impedance
mismatch. This op allows casting a value from one type to a type that it
is a subtype of to model this behavior.

Recommended review order:
- TorchOps.td for new torch.derefine (and updated docs for
  `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required
relayering the code so that graph_importer.cpp/.h is now
function_importer.cpp/.h because only the `torch::jit::Function`
(actually the `c10::FunctionSchema` it holds) knows the derefined types that are
actually needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions which are now prefixed with
`__torch__.` but that is more correct anyway (that is their linkage name
in the `torch::jit::CompilationUnit`; the previous `mb.import_function`
was actually buggy in the case of functions calling each other as it
would reference their unqualified name).

With this change, we can import `resnet18` from `torchvision` :)
IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-03 15:09:44 -08:00
Bryce Arden 1736ff0253 [prim] Add TupleIndex support
I could not find a corresponding ListIndex in prim, which seems to
translate to a __get_attr__ under the hood. I think the reason a tuple
Index op can exist is because Tuple's are supposed to be frozen, where
List operands can be mutable.
2021-03-02 17:28:32 -08:00
Bryce Arden 68338eafb7 [chore] Make variable names in prim.py more clear 2021-03-02 17:28:32 -08:00
Bryce Arden ca3a02da28 [prim] Add support for List|TupleUnpack 2021-03-02 17:28:32 -08:00
Sean Silva df4c5764da Add support for `prim::unchecked_cast`.
This arises when casting optionals, which happens a lot especially
around handling of default arguments (python `if arg is None` idiom).

In this case, the offending code for the model is in max_pool2d:
[code link](b3bf08e67f/torch/nn/functional.py (L657))
2021-03-02 16:01:34 -08:00
Sean Silva 939d36906f Add support for prim::Loop op.
This is a funny one. It combines a `for` and `while` loop in one op. We
will need to write some conversions to `scf`.
2021-03-02 16:01:34 -08:00
Sean Silva 7dfd6f697e Add support for prim::RaiseException.
Used by resnet18.

It seems to originate from a helper `_verify_batch_size`:
[code link](b3bf08e67f/torch/nn/functional.py (L2099)).

I couldn't find a way to test `prim::RaiseException` without also having
`prim::Uninitialized`.
2021-03-02 16:01:34 -08:00
Sean Silva c837dbb077 Properly import the entire torch::jit::CompilationUnit
This primarily unlocks proper handling of free functions (that is,
functions that are not methods of any torch.nn.Module).

Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff

The `torch::jit::CompilationUnit` is basically a backing store or
"context" holding all the possible functions in the program. The
previous code was not explicitly accessing this data structure, since it
just imported the `torch::jit::Function`'s that it saw attached to
methods.

Subtly, any time a TorchScript module called into a free function, the
free function gets incorporated into the torch::jit::CompilationUnit,
but doesn't show up anywhere when dumping the module, except in the
curious pattern:

```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```

That is, calls are indirect calls, and are accessed via `prim::Constant`
materializing a function object. Even stranger, the `name` attribute here
doesn't really even tell the full story -- it doesn't correspond to
anything. It turns out that the c10::FunctionType itself actually holds
a pointer to the `torch::jit::Function` in the compilation unit
directly (so there is actually no indirection in prim::CallMethod,
because any two values of the same FunctionType call the same
function!). E.g. when converting the IR to bytecode, the "name" is
ignored [code link](1d6bd15790/torch/csrc/jit/runtime/interpreter.cpp (L937)).
We do import `prim::CallFunction` as a `std.call_indirect` though
because it's more braindead to do it that way (it gets canonicalized to
a direct call easily).
2021-03-01 12:08:01 -08:00
Sean Silva 59a3f46795 Add support for prim.NumToTensor
With this, we can import BERT!
```
pt_util ~/tmp/bert.pt  --import --exported-name=forward \
| npcomp-opt -torch-globalize-object-graph -inline -symbol-dce
```
https://gist.github.com/silvasean/fe7735ff5d065cc9216f7b0346d0e977

The test case here is a bit unconventional -- it isn't actually valid
Python. To figure out how to generate it I had to go search the PyTorch
codebase for "NumToTensor" and work backward. In this case I found
this
[code](649760e5f1/torch/csrc/jit/frontend/ir_emitter.cpp (L464))
which via a wild guess I was able to turn into a test case.

In this case it didn't take me too long, but when doing this kind of
"add a bunch of trivial stuff to bring up a real model", I'm starting to
think that we might skimp on test cases when it's fairly trivial and not
obvious how to test with a small test.
2021-02-26 10:16:56 -08:00
Sean Silva 7b6fa27838 Rename tests to match the code they test
- `module_import -> ivalue_import`, as it mainly tests ivalue_importer.cpp
- `graph_import -> node_import`, as it mainly tests node_importer.cpp
 - graph_importer.cpp does call into node_importer.cpp, but doesn't do
 much.

This was getting pretty confusing. Also add README.md's in each
directory for more clarity.
2021-02-25 13:31:33 -08:00
Bryce Arden 27a4515de2
Add Conv2D Torchscript Import Support (#167)
Adds support for lowering a torch.nn.Conv2d module to the Torch Dialect through TorchScript import.
Generated IR can be viewed here:
https://gist.github.com/brycearden/6c0f790115c4577249372ef82768e6fd

Required implementing support for tuple in the ivalue importer and list in the node importer.
2021-02-25 12:14:00 -08:00
Sean Silva a375ccf9da Add ability to annotate TorchScript classes.
The first use case is to annotate certain program constructs as either
exported or private. In this commit we plumb it down to
GlobalizeObjectGraph which makes use of this information.

Recommended review order:
1. class_annotator.h/.cpp + `test/module_import/annotations/*`
    - New abstractions to communicate with Python code and annotate.
2. IR changes in TorchOps.td
    - Adding "private" attribute to various things.
3. ivalue_import.cpp changes
    - Module + ClassAnnotator = annotated IR
4. GlobalizeObjectGraph.cpp + tests
    - use new "private" attributes to create "private" IR.
    - also, tweak some of the op deleting mechanics, which was triggering
      some memory errors / assertions

With this, we can run the classifier through and inline it as follows:
```
frontends/pytorch/utils/pt_util.py --import --exported-name forward ~/tmp/classifier.pt \
| npcomp-opt -torch-globalize-object-graph -inline
```
IR: https://gist.github.com/silvasean/32dcad9f6270557f412094a77cecdd69
2021-02-25 11:28:34 -08:00
Sean Silva 8486968925 Add trivial inliner interfaces.
With this + manually setting private visibility on everything, a simple
classifier can be reduced to this IR, which is looking pretty lean and
mean:
https://gist.github.com/silvasean/19e7e2e21a61ff197aeac0dd864d188f

Also, include a utility script for importing `.pt` models.

```
pt_util.py --import classifier.pt | npcomp-opt -torch-globalize-object-graph
```
2021-02-22 10:40:38 -08:00
Sean Silva 1b769f7841 Extend GlobalizeObjectGraph to handle torch.prim.GetAttr returning NnModuleType
This happens in practice. With this, we can globalize slots for the
non-trivial classifier layer obtained from
https://github.com/NVIDIA/NeMo/blob/main/tutorials/nlp/Joint_Intent_and_Slot_Classification.ipynb

This also adds support for tuple return types, which were needed by that
model.
2021-02-19 10:23:25 -08:00
Sean Silva 158c5c484d Implement GlobalizeObjectGraph transformation.
This required restructuring of how we model TorchScript on import. The
main difference is that now we split out a `torch.class_type` that holds
methods and declarations of the types of each slot. This is more
consistent with TorchScript (our previous representation was
"denormalized").

Recommended reading order:
1. check out the description of `torch.class_type` in `TorchOps.td` and
   look at `test/Dialect/Torch/ops.mlir` and
   `frontends/pytorch/test/module_import/` to familiarize with the new
   representation.
   - Just look at the new IR. The diff between the old names and new
     names is confusing.
2. check out `test/Dialect/Torch/globalize-object-graph*.mlir`
   and read along with the pass description in
   `include/npcomp/Dialect/Torch/Transforms/Passes.td`
3. Read the code in `GlobalizeObjectGraph.cpp` and miscellaneous changes
   in `ivalue_importer.cpp`, `TorchOps.cpp`, etc.
2021-02-18 18:18:47 -08:00
Bairen Yi 99d1db18d2 Add NoneType support for ivalue_importer
PyTorch added a Global variable `_is_full_backward_hook` recently.

See https://github.com/pytorch/pytorch/pull/46163

Signed-off-by: Bairen Yi <yibairen.byron@bytedance.com>
2021-02-18 11:20:29 -08:00
Stanley Winata a38b7b72b2 adapt acap_dispatch to latest pytorch nightly ("1.9.0.dev20210215+cpu")
Modify ACAP_Dispatch to work with latest pytorch
-Remove boxed from convolution's m.impl
-Use redispatch and constrainted keyset to replace deprecated
callwithdispatchkey
2021-02-18 11:13:29 -08:00
Sean Silva 498979ad28 Add MLIR diagnostic handler that prints to `sys.stderr`.
This is needed so that output shows up properly in a Jupyter notebook.
2021-02-17 18:50:05 -08:00
Sean Silva 572163dfde Handle object identity correctly.
This required some careful considerations when defining object identity
for tensors. See the comments for how we do it.

This also tracks some basic information for diagnostics.
2021-02-10 15:15:56 -08:00
Sean Silva 7f7bf39551 Add prim::Print and fix prim::CallMethod
For now, we are treating strings as bytes.
2021-02-10 15:15:56 -08:00
Sean Silva c4e4a11e3f Add support for prim::GetAttr/SetAttr/CallMethod/If
This required some invasive surgery to graph_importer.h/cpp,
specifically moving most of it into node_importer.h/cpp and relayering
it. The abstraction that it had didn't work well in the recursive
setting that happens with prim::If.

The key observation is that torch::jit::Graph doesn't really correspond
directly to anything on the MLIR side. It's a weird combination of a
context, builder, and function and just holds a `torch::jit::Block`. It
is `torch::jit::Node` and `torch::jit::Block` which form the recursive
structure analogous to MLIR's operation/region/block. So
node_importer.h/cpp makes sense as a core building block.

As part of doing this, I did venture a bit into the AcapController code,
and realize now that there is functionality duplicated there with the
ivalue importer. Will refactor that soon.
2021-02-04 17:01:47 -08:00
Sean Silva 99b845411d Rename some tests for consistency 2021-02-01 17:01:18 -08:00
Sean Silva 689b40c7a6 Add initial TorchScript module importer
It turns out that this was easiest to structure as a general IValue
importer, since torch module are just one of the possible IValue's.

We import the IValue object graph in a braindead fashion into basicpy
ops and a new `torch.nn_module` op that is used to model the
attributes/methods of a torch::jit::Module IValue. See `Torch/ops.mlir`
for an example, and also check out the .py import tests in
`frontends/pytorch/test/module_import`.

As part of this change, a few housekeeping tasks:
- extract some helpers from graph_importer.cpp
- more helpers around the C API
- misc touchups
2021-01-28 11:55:17 -08:00
mikeurbach 0f6a65a1c5
Enable building using LLVM_EXTERNAL_PROJECTS. (#152)
This allows building NPCOMP as an external project of LLVM, similar to
how CIRCT can be built: https://github.com/llvm/circt/pull/227.

The CMake options to use this build style look like this:

```
  -DLLVM_EXTERNAL_PROJECTS=npcomp \
  -DLLVM_EXTERNAL_NPCOMP_SOURCE_DIR=/path/to/mlir-npcomp \
```
2021-01-26 11:43:43 -07:00
Sean Silva b92d724179 NFC: Rename "graph_export" to "graph_import"
These mainly exercise the `module_builder.import_function` function, so
it makes sense for the directory to be called "graph import".
2021-01-21 12:17:21 -08:00
Sean Silva d818043986 Bump llvm-project to d50d7c37a159802c89454a6c53c0ec2e7949d84a
Fixes:
- use `op->(method on Operation)`
- update for MlirIdentifier in signature of mlirNamedAttributeGet
2020-12-14 14:30:51 -08:00
Stella Laurenzo f6d7ee06ef Make torch_mlir compatible with binary PyTorch installations.
* This has been anticipated for a long time in that it is quite hard to keep C++ binary compatibility across a system landscape as diverse as PyTorch, LLVM, and this project. This is why we based the PyTorch extension on the MLIR and NPCOMP C APIs only: that is the only sane linkage story for the entire matrix.
* Removes the few LLVM'isms in torch_mlir that had snuck in, using either STL or PyTorch support utilities. The new rule here is that LLVM C++ includes are forbidden at this level and (as stated in the design), torch_mlir should use the PyTorch runtime and support libraries (not introduce an incidental C++ dependency on LLVM).
* Also deletes mnist-playground as it was proving impossible to keep the grid of PyTorch vs system ABI divisions functioning. I am open to a less drastic course here (optional/disabled by default?)
* This gets us pretty close to just using PyTorch's extension builder API, which will be nice for distribution (i.e. it integrates well with the PyTorch ecosystem for deployment). I ended up just simplifying the in-tree CMake support for now.
* Fixes #138
2020-12-14 09:51:00 -08:00
Sean Silva b2077738ca Bump llvm-project to 444822d77a7fea28aa49edf24533c987efa1b2ee
Fixes:
- renames StandardTypes -> BuiltinTypes
- std.extract_element -> tensor.extract
2020-12-11 14:43:38 -08:00
Phoenix Meadowlark 699bf5df45
Add cos_e2e.py, test_utils and support for tensor inputs (#134) 2020-11-24 19:02:50 -08:00
Stella Laurenzo e2405e3ca8 Add design sketch for aten fallback. 2020-11-24 18:13:35 -08:00
Stella Laurenzo 3937dd14cb Add basicpy.numeric_constant op.
* Going through TODOs on the PyTorch side, this is a big cause of them (not being able to have constants for signed/unsigned).
* Added complex while in here since we're at the phase where it is better to just have things complete than partially done.
2020-11-24 16:44:40 -08:00
Stella Laurenzo b0623b7793 Bump LLVM version to 4f5355ee73626f8b8fe6bf0dd6d167fea7628a2c.
* Incorporates changes around LLVM StringRef.
* Ports fix in upstream pybind11 detection.
* Disables CI hack due to broken pybind detection.
2020-11-24 13:12:04 -08:00
meadowlark@google.com 959c0a79cb Expand pytype coverage for torch_signature_ods_gen.py 2020-11-24 12:42:32 -08:00
Stella Laurenzo f13994fdf7 NFC: Remove TODO about creating an mlirOperationStateDestroy (unnecessary). 2020-11-23 15:01:51 -08:00
Stella Laurenzo 9ffd2556ab Add TorchScript import tests missed in previous change. 2020-11-23 14:43:42 -08:00
Stella Laurenzo 78a3c90758 Add TorchScript graph importer.
* Does not handle all features yet but should conservatively fail on unsupported things.
* Location tracking is still somewhat mismatched between what TorchScript and MLIR do. Likely need a better heuristic for tracking locations from defs for nodes that do not carry location.
* Sets the ground-work for a specialized/generic split but only implements the generic side.
* Had some evidence that this requires a recent bump of PT nightly (within the last month) to pick up pybind11 2.6, which includes some cross-module symbol fixes (vs the previously sync'd version). No source changes, but older versions fail to cast function types at runtime.
2020-11-23 14:20:09 -08:00
Sean Silva 1dfcfa9cd1 Add aten.mm op and "test" it e2e.
Note that unlike aten.matmul which has dynamic behavior
depending on the argument ranks (can do matrix-matrix, matrix-vector,
batch matmul, etc.), aten.mm is just a vanilla matrix
multiply, which can be lowered precisely to tcf.matmul.

The "test" is really just an example that I stared at while getting my
feet wet with this. We probably want something that actually tests this
as part of `ninja check-npcomp`.
2020-11-20 17:21:24 -08:00
harsh-nod 67d6694fdc
Update PYTHON cmake variables to Python3 (#121)
After the recent change of cmake variables
from PYTHON_INCLUDE_DIRS to Python3_INCLUDE_DIRS
and PYTHON_LIBRARIES to Python3_LIBRARIES, there were
a few files that still had references to the old
variables. This patch fixes that.
2020-11-17 16:04:14 -08:00
Stella Laurenzo 6850295ec5 Teach cmake how to find the installed PyTorch.
* In most situations, this eliminates the need to explicitly set a path to the Torch cmake files.
* Also upgrades to new Python3 find package. (should eliminate 2.x mismatches)
* Since PyTorch is located by asking Python where it is, this eliminates a lot of causes of mismatch. (one source of truth)
2020-11-13 17:19:25 -08:00
Stella Laurenzo 47ac80491c Delete old PyTorch 1.3 type dispatch oriented code paths.
* We aren't quite at e2e parity, but we aren't going back and the old path is bit-rotted.
2020-11-12 22:27:05 -08:00
Stella Laurenzo e359167562 Fix dispatch of arange.
* Fixes #107
* I wouldn't say I love what had to be done here. Worth a conversation with the PT devs (probably as part of a rollup of a bunch of this stuff).
2020-11-12 22:07:23 -08:00
Stella Laurenzo b4c7ae1e0c Repurpose numpy-compiler compiler/runtime flow for PyTorch.
* A bit gross because I took the chance to upgrade all of the backend bits to the new MLIR Python bindings and we still co-mingle the old and new for now.
* Since the Python created PassManagers are configured for explicit nesting, I had to upgrade some of the pass pipelines to be explicit.
* The demo in mul_maximum_e2e.py now compiles, runs through PyTorch and through the JIT, prints and asserts the same results.
* I am not claiming that this is the prettiest API in this patch: consider that this is just directly using low-level APIs and there should be an intervening high level API.
2020-11-11 10:38:13 -08:00
Stella Laurenzo e60dc2470e Add aten.maximum op and conversions from aten->tcf.
* Conversions are very simple, suporting mul, maximum and add (alpha=1 only).
* Example added with pass pipeline needed to run.
* Much missing off of the golden path but sufficient for such simple cases.
2020-11-04 17:20:54 -08:00
Stella Laurenzo 6c702b149f Add a number of kernels and new patterns.
* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_
* Extends the recognition logic and metadata for handling inplace transformations, optional tensors, ints, lists and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen.
* The result *almost* comes out as a pure tensor program with the exception of the copy_ op, which I will do some followup work to deal with.
* More progress on #97
2020-11-04 14:36:59 -08:00
Harsh Menon c2d3820e48 Fix insertion point bug #102
The current code was inserting all build_list ops
after the last constant op since it was assuming that all
elements being passed in were constants.

This patch replaces that patch with a new function that
inserts the build_list ops before the terminator.

Also modifies test_export_conv2d_fwd.py since its output
no longer matches.

TEST: Added test_export_cat.py which is the code in #102
2020-11-02 16:41:26 -08:00
Stella Laurenzo 0c73c535d6 Capture backward conv and copy_ kernels.
* This is sufficient to capture the forward and backward pass and gradients of a convolutional model with an nllloss.
* As with the forward conv, the backward conv is a special case wrapped in an enigma on the PyTorch side. There aren't many like it, so special casing is just what we do.
* When I traced this, I found that the copy_ op is not yet boxing compatible so I had to map it manually. If there are many more like this, I'll probably do something a bit more clever to reduce duplication.
* This exposes new signature patterns that will need to be handled by the ATen lowering. Will take care of that next: It will be nice to have an e2e of a non-trivial case with full gradients.
* Fixes #97.
2020-10-30 22:59:26 -07:00
Stella Laurenzo 8d98dd4551 Support optional args/returns and other odds and ends.
* None's out Device? args.
* Emits bool tensors if needed.
* Adds some stderr tracing to better see what is going on.
* Test case that exercises NLLLoss.
* This test case emits something for backward calculations but there are some issues still to be worked out, so that part is left out of the test case.
* Progress on #97
2020-10-30 14:50:28 -07:00
Stella Laurenzo c08935a418 Rewrite ATen ODS code generator to be based on new op registry and new signature recognition system.
* Deletes prior code generator from previous attempt (moved some of it into this one).
* Renames old generated tablegen source to "Legacy".
* Generates ODS and import rules for most binary and unary arithmetic ops.
* Removes old generated ops and integration tests that were testing details of the prior setup.
2020-10-28 10:37:37 -07:00
Stella Laurenzo 510f226df2 Expose signature metadata to ops and implement ATenRecognizeKernelsPass pass.
* Two op interfaces, one for querying instance metadata and one for getting static data needed to construct an op from a generic form.
* For torch.generic_kernel ops, metadata is splatted in during capture from Torch (it comes from the op registry, which will work for either device capture or graph import).
* Moved the 'add' out of the generated set so I can experiment on it. It implements the TorchBuildableKernelOpInterface interface which provides its metadata.
* The ATenRecognizeKernelsPass pass generically lowers from a torch.generic_kernel to recognized ops that implement the TorchBuildableKernelOpInterface, handling the various types of transformations that we allow at this stage.
2020-10-26 20:31:45 -07:00
Stella Laurenzo 91fc83d2e7 NFC: Transition ATen passes to tablegen registration. 2020-10-22 17:12:44 -07:00