Finish supporting importing the vast majority of `onnx` operations. This
includes:
- region support
- region value inherentance
- `torch.string` support
- `torch.list` support
- `torch.optional` support
Also note that we are in the process of proposing SparseTensorMetadata
to PyTorch FX graph export (see
https://github.com/pytorch/pytorch/pull/117907). This will hopefully
eventually replace the current data structures in torch-mlir.
There is no reason to treat `ConstantOfShape` as a specialized import
any as there exists a onnx-to-torch equivalent. Dropping the import
coding and adding support for resource conversion substantially
increases test coverage for dynamically shaped tests.
As of https://github.com/pytorch/pytorch/pull/118969, `ExportedProgram`
has the long awaited fixes to correctly categorize various things
relating to parameters, buffers, mutated inputs and constants.
With this additional modeling, we are finally able to implement
(safely/soundly) the mutable semantics that were attempted on the
TorchScript path. The difference is that on that path, we had to
conservatively treat everything as mutable and run some dodgy heuristics
(which have been the cause of many bugs relating to
"MaximizeValueSemantics") to try to get back to an immutable state.
The new model supports mutability at the graph edges, allowing both user
inputs and buffers to be mutated (there is some more support than that,
but that is all I fully tracked through to implementation).
Therefore, when we receive programs like this, we now can selectively
enable mutation at the edges. This happens to be the mutability model
that IREE supports, which I expect to be a primary beneficiary. However,
there is nothing stopping anyone else from handling the `!torch.tensor`
types and the existing copy/overwrite ops that will be selectively
added.
Since this relies on API changes that will not release until 2.3, I'm
being a bit cautious about not refactoring existing facilities.
We can route the torch tests via `onnx` using the `torch.onnx.export`
tooling. We can then reimport, lower to torch, and compile to linalg to
validate the onnx path is working correctly.
The current implementation exposes some failures in the `onnx` path so
we cannot enable the onnx test suite yet due to segmentation faults.
This commit adds decomposition support into the core aten operators
before importing the module from torch.
Also, this commit deals with the lifted tensor constants in
torch.export.export(). We don't want to add unnecessary placeholder
nodes in the graph (extra args in the block module), and should treat
them like the constants that they are. The unnecessary clone is also
removed for max efficiency.
The investigation is largely recorded in
https://github.com/llvm/torch-mlir/pull/2881, but this change allows us
to capture non-persistent buffers that were lifted as tensor constants
(after https://github.com/pytorch/pytorch/pull/118969 landed in upstream
PyTorch), and propagate them to `Torch` dialect as "frozen"
`torch.vtensor.literal`. I believe this patch should work with both
nightly and stable PyTorch, but will let CI confirm the same. Thanks
@stellaraccident for the valuable pointers and guidance.
---------
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Various improvements on sparsity metadata:
(1) define single data structure for all sparsity related metadata
(2) handle batched dense dimensions, as well as dense subtensor
dimensions
(3) refine sparsity propagation for deeper networks
This PR introduces a sparse_jit wrapper that can run simple models with
sparse tensor inputs end-to-end. The implementation shows all required
components on modifying sparse tensor types with a 1:N relation on the
call sites. Two tests shows that the JIT runs end-to-end while computing
the correct results.
More details to follow (generalizing to COO and different ranks, as well
as support for *output* sparse tensors), but the general concepts are
all here now.
**_Update: Thanks to Rob, bump to proper LLVM/MLIR hash is done!_**
_**NOTE that all parameter passing changes are nicely done "downstream"
in MLIR, so very little changes are required in torch-mlir code
proper**_
---------
Co-authored-by: Franz Haniel <77495327+frafranz@users.noreply.github.com>
Co-authored-by: Franz Haniel <franz.haniel@amd.com>
Adds an escape hatch from creating a DenseResourceElementsAttr for
single value tensors into DenseElementsAttr.
For 0d or 1element, splats are better as DenseElementsAttr. Don't use
DenseResourceElementsAttr for it
To handle the conversion from raw bytes to `DenseElementsAttr` we need
to handle the endianness conversion during `torch-onnx-to-torch`.
Therefore when importing `onnx.Constant` it is better to represent using
the `onnx` constant operation so that only one location requires the
endianness correction.
Note that we are waiting for actual FX traced graph support for sparse
tensors. For details see
https://github.com/pytorch/pytorch/issues/117188
Until then, however, we provide this clever importer that builds the FX
traced graph for for the dense case and then puts a sparse annotation
back on the parameters.
With import test.
Fixes https://github.com/llvm/torch-mlir/issues/2764
In the case of OPT, there are ConstantOfShape ops whose input shape is
not static (that is, an initializer), but rather comes from a Constant
op. The importer can't handle such non-static input shapes.
The fix here is to create initializers for a subset of Constant ops
(ones with "value" attributes), so that their outputs can be used
statically. Additionally, there was no case for creating a splat of
int64, so I added that as well.
---------
Co-authored-by: Dave Liddell <dliddell@xilinx.com>
Changes made during upstreaming:
* Removed comments attributing some copied code back to torch-mlir
(since it is now repatriated).
* Re-organized imports.
* Inlined RefMapping/RefTracker and TypeSubclassMap from an external
utility module.
* Added FxImporter class comments.
* Updated stack trace extraction to be fail safe.
* Added an entry-point for `import_frozen_exported_program` which uses
the shiny new upstream `torch.export.export()` API (versus the
lower-level/older API that Turbine is presently using). This
necessitated a small FX rewrite to line external state management up
with current conventions.
* Adapted one of Turbine's importer tests to go with this initial
submission. Turbine unfortunately has a lot of more-integration-ey
tests, and I would like to extract those as more of unit tests of the
importer features and upstream them that way vs trying to copy directly.
For now, one overall test with the initial submission gets us moving.
I acknowledge that there are some code quality things that could be
improved in this submission: this was authored over the course of many
months (and often via some trial and error). I would like to keep it
relatively converged with the downstream for the next few steps while
getting the test suite upstreamed. And then it will be easier to take a
hygienic pass through the code.
Including co-authors for contributors in the git log of the original
repository.
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Avinash Sharma <aviator1994@gmail.com>
Co-authored-by: Arham Khan <arhammkhan@gmail.com>
Co-authored-by: brucekimrokcmu <kwangkyk@alumni.cmu.edu>
Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
This is part 1 of 2, which will also include upstreaming the FX
importer. I started with ONNX because it forces some project layout
updates and is more self contained/easier as a first step.
Deviating somewhat from the RFCs on project layout, I made the following
decisions:
* Locating the `onnx_importer.py` into `torch_mlir.extras` as Maks
already has opened up that namespace and it seemed to fit. Better to
have fewer things at that level.
* Setup the build so that the root project only contains MLIR Python and
pure Python deps (like the importers), but this can be augmented with
the `projects/` adding more depending on which features are enabled.
* The default build continues to build everything whereas in
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode, it builds a
`torch-mlir-core` wheel with the pure contents only.
`onnx_importer.py` and `importer_smoke_test.py` are almost verbatim
copies from SHARK-Turbine. I made some minor local alterations to adapt
to paths and generalize the way they interact with the outer project. I
expect I can copy these back to Turbine verbatim from here. I also
updated the license boilerplate (they have the same license but slightly
different project norms for the headers) but retained the correct
copyright.
Other updates:
* Added the ONNX importer unit test (which also can generate test data)
in lit, conditioned on the availability of the Python `onnx` package. In
a followup once I know everything is stable, I'll add another env var
that the CI can set to always enable this so we know conclusively if
tests pass.
* Moved the ONNX conversion readme to `docs/`.
* Renamed CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` ->
`TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted the sense. Made the
JitIR importer and LTC options `cmake_dependent_options` for robustness.