Commit Graph

1130 Commits (874fdb7e429175b701602e08df027f756bdf6ba9)
 

Author SHA1 Message Date
powderluv 0f751498a7
Update releaseSnapshotPackage.yml 2022-04-22 15:38:36 -07:00
powderluv d789aee11e
Only upload torch*.whl (#786)
only upload torch*.whl to unblock OSX build failures during upload. We have to move to svenstaro/upload-release-action
2022-04-22 15:17:09 -07:00
Prashant Kumar e9c785b04b Generate backward graph via functorch-aot module
Example to demonstrate the extraction of forward as well as
backward graph via Functorch's AOT module is added.
2022-04-22 20:58:35 +05:30
Ashay Rane 28bf9cc1fc doc: [nfc] add instruction for running Python regression tests
Prior to this patch, the top-level README did not include the line for
running the Python regression tests in `//python/test`.  This patch
fixes the problem by adding a line to run the `check-torch-mlir-python`
target.
2022-04-22 10:54:04 -04:00
powderluv cbf158f069
Update buildRelease.yml
Update artifact directory to ./build_tools/python_deploy/wheelhouse/*.whl
2022-04-21 19:57:27 -07:00
Prashant Kumar 5cdef0213d [LINALG] Bug fix i64 vs i32 type comparison.
Comparing index type instead of integer types solves the problem.
2022-04-22 08:09:58 +05:30
powderluv 9f2184da98
Update oneshotSnapshotPackage.yml
remove now deprecated inputs to build and test
2022-04-21 19:12:42 -07:00
powderluv 8003b92fa7
Delete releasePackage.yml 2022-04-21 18:54:01 -07:00
powderluv c1026fa95b
Switch to using the new Release builds (#780) 2022-04-21 18:46:34 -07:00
powderluv 4ef61aa27f
Minor buildsystem fixes (#778)
Sets up auto-pinning of latest torch-nightly
2022-04-21 15:53:00 -07:00
powderluv 0257d91a21
Update buildManylinux.yml
use sudo for mac OS
2022-04-21 11:06:02 -07:00
powderluv 299c1bbe6d
Update buildManylinux.yml
fix build naming
2022-04-21 10:55:40 -07:00
powderluv b03eac4224
Enable OSX (Intel, Apple Silicon Builds) (#776)
Update pinned pytorch version. Will submit a follow on PR to bump.
Also update artifacts directory
2022-04-21 10:47:28 -07:00
powderluv cc3a4a58ef
Add oneshot release snapshot for test/ondemand (#768)
* Add oneshot release snapshot for test/ondemand

Add some build scripts to test new release flow based on IREE.
Wont affect current builds, once this works well we can plumb it
in.

Build with manylinux docker

* Fixes a few issues found when debugging powderluv's setup.

* It is optional to link against Python3_LIBRARIES. Check that and don't do it if they don't exist for this config.
* Clean and auditwheel need to operate on sanitized package names. So "torch_mlir" vs "torch-mlir".
* Adds a pyproject.toml file that pins the build dependencies needed to detect both Torch and Python (the MLIR Python build was failing to detect because Numpy wasn't in the pip venv).
* Commented out auditwheel: These wheels are not PyPi compliant since they weak link to libtorch at runtime. However, they should be fine to deploy to users.
* Adds the --extra-index-url to the pip wheel command, allowing PyTorch to be found.
* Hack setup.py to remove the _mlir_libs dir before building. This keeps back-to-back versions from accumulating in the wheels for subsequent versions. IREE has a more principled way of doing this, but what I have here should work.

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2022-04-21 02:19:12 -07:00
Prashant Kumar 33c9d256ea [REFBACKEND] Add support for returning multiple different return types.
Added the dynamic registration of return function to the execution
engine. This makes sure that  different/multiple return types are supported.
Also, updated the .style.yapf indentation to 4.
2022-04-21 09:02:30 +05:30
Sean Silva b69db60f85 Pin the Python package to the exact PyTorch nightly.
This avoids issues where PyTorch version drift has made things
incompatible.

One caveat is that you will need to specify
`-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre` on the command line for pip to know where to find the nightly
packages (there is no way around this) -- this is easiest to do by
simultaneously passing `-r requirements.txt` on the pip command line.
2022-04-20 16:47:38 -07:00
Sean Silva 712b78c674
Change preferred style to be PEP8 2022-04-20 14:38:19 -07:00
Sean Silva 075464fa74 Add a new `torch_mlir.compile` method.
This makes it much easier to convert models and hides all the
ClassAnnotator complexity.

This also adds a new example `torchscript_resnet18_all_output_types.py`
which shows the ResNet18 IR for all output types.

Also,

- This moves `run_pipeline_with_repro_report` to
  `torch_mlir.compiler_utils`.
2022-04-20 10:06:01 -07:00
Clément Fournier 578d0ec292 Review comments 2022-04-19 15:11:17 -07:00
Clément Fournier 3e0c1cf6af Change cache suffix to not invalidate existing caches 2022-04-19 15:11:17 -07:00
Clément Fournier 566650c5ae Use distinct ccaches
Since they run in distinct jobs, using the same ccache would
cause one job to overwrite the cache of the other.

See https://github.com/ljfitz/torch-mlir/pull/16 for a proof
that this works. The first build takes a long time but ccache
takes over in the dummy commit.
2022-04-19 15:11:17 -07:00
Clément Fournier 8d700dee21 Improve README 2022-04-19 15:11:17 -07:00
Clément Fournier f9d5201ae6 address PR review 2022-04-19 15:11:17 -07:00
Clément Fournier 4a2535a86d Add build-out-of-tree job 2022-04-19 15:11:17 -07:00
Clément Fournier 37087ccd5f Refactor current CI workflow into composable jobs 2022-04-19 15:11:17 -07:00
Clément Fournier 2a0c567418 Add README instructions for OOT build 2022-04-19 15:11:17 -07:00
Sean Silva 3b5310d6d2 Move COMMON_TORCH_MLIR_LOWERING_XFAILS into test_suite
That way, downstreams don't have to duplicate this list.

Also, remove "external config" feature, since it is subsumed by just
importing the test suite.
2022-04-19 14:32:58 -07:00
Vivek Khandelwal 769f3a8870 [MLIR][TORCH] Add E2E support for max_pool2d_with_indices op
This commit adds lowering of `max_pool2d_with_indices` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-18 21:05:19 +05:30
Ashay Rane d3c08376af
test: add end-to-end test for aten.neg (#760) 2022-04-15 12:37:57 -07:00
Ashay Rane a893c7d5cf
Add shape transfer function and lowering to linalg for aten.neg (#759)
* shape: add shape transfer function for aten.neg

Prior to this patch, the list of shape transfer functions did not
include `aten.neg`, which resulted in errors like below.

```
error: unsupported by backend lowering: tensor with unknown rank or dtype
note: see current operation: %0 = "torch.aten.neg"(%arg0) :
  (!torch.vtensor<[256,256],f32>) -> !torch.vtensor<*,f32>
note: this is likely due to a missing shape transfer function in shape_lib_gen.py
```

This patch fixes the problem by adding a shape transfer function to
reflect the point-wise nature of this operation.

* linalg: add translation of aten.neg operation

This patch adds a translation rule to lower `aten.neg` operations on
tensors to an `arith.negf` operation wrapped inside a `linalg.generic`
operation.  This patch also adds a rudimentary test.
2022-04-15 11:11:22 -07:00
Vivek Khandelwal 1bccb4fc8a [MLIR][TORCH] Add E2E support for aten::max_pool2d_with_indices_backward op
This commit adds lowering of `aten::max_pool2d_with_indices_backward` op.

This commit also fixes formatting issues in basic.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-14 21:46:47 +05:30
powderluv 91d3e7ba15 Remove CCACHE settings and validate on OSX
Builds whl package for OSX. Need to validate smoke tests next
2022-04-14 01:32:49 -07:00
Maksim Levental 24f9de7120
Fixes https://github.com/llvm/torch-mlir/issues/751 where `torch.bool` is parsed as signless `i1`. (#752) 2022-04-13 12:28:27 -05:00
Maksim Levental d46f169c1a
Fix kwarg annotation in eager (#747) 2022-04-11 17:35:42 -05:00
Maksim Levental 66de821eaf
small framework plus build_script_function (#745) 2022-04-11 16:53:52 -05:00
Maksim Levental 18ef40acaf
Fixes a bug in use of upstream `normalize_function` in our `normalize_args_kwargs` (in eager mode) and introduces unit tests. (#740)
NB: `shouldnt_normalize2` and `shouldnt_normalize3` currently XPASS i.e., args *will* successfully normalize despite being incorrect due to an [upstream bug](https://github.com/pytorch/pytorch/issues/75342).
2022-04-11 16:17:44 -05:00
gpetters94 9ec0683e92
Add 2D case for convolution (#693) 2022-04-08 00:47:57 -04:00
gpetters94 fa0b24a73c
Rename optional list types (#643) 2022-04-07 18:15:51 -04:00
Sean Silva e7721fb784 Fix error message.
RefineTypes doesn't handle shape refinement anymore.
2022-04-07 14:46:44 -07:00
Prashant Kumar 1d5b5a89e8 [LINALG] Add torch.layout information
torch.layout information has been added.
2022-04-07 20:47:49 +05:30
Ahmed S. Taei eaf34fa02b
Add bazel build support (1/N) (#706)
This PR adds rules for building the compiler part with bazel, a followup PRs will build the python bindings.
2022-04-06 11:20:39 -07:00
Prashant Kumar fb8cb0c5f3 [LINALG] Add the lowering of `aten.ne.Scalar` op
The lowering of `aten.ne.Scalar` op has been added to
the linalg backend.
2022-04-05 21:07:28 +05:30
Ramiro Leal-Cavazos 5620fe030e
Add 1D, weight, and reduction support to nll_loss_backward (#729)
This commit adds the following support to the op `nll_loss_backward`:
- `input` tensor can be rank-1
- `weight` parameter
- `reduction` parameter
- `target`, `grad_output`, `total_weight` can be rank-0
- Checks that input tensors are of the expected type
2022-04-04 10:57:49 -07:00
Clément Fournier 886ad169e5
Fix out-of-tree build of torch-mlir-dialects (#726)
Follows up on #623 for out-of-tree builds of torch-mlir, which
added building `torch-mir-dialects` as a subdirectory.

Our goal is to support both in-tree and out-of-tree builds of
`torch-mlir` with minimum hassle, for instance by using the same
variable names in both setups.

Specific changes to `externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt`:
- We use `MLIR_FOUND` to detect that it is being build as a subdirectory
and the llvm+mlir cmake infrastructure is already set up (via
find_package in the parent build) as opposed to an in-tree build.
- For in-tree, the setting of variables and loading of llvm+mlir cmake
infrastructure is now conditionally performed.
- For in-tree, the names of cmake variables being defined for are
adjusted to match those `llvm-project` makes available through
`find_package(MLIR REQUIRED CONFIG)`, under the assumption that those
are the more "standardized" names.

Co-authored-by: Clément Fournier <clement.fournier@amd.com>

Co-authored-by: Liam Fitzpatrick <liam.fitzpatrick@xilinx.com>
2022-04-04 11:37:28 +02:00
Sean Silva e1c7c1f9c5 Update diagram for TOSA backend. 2022-04-01 22:46:25 +00:00
Sean Silva 14cf87633c
Add link to forum post describing `__torch_dispatch__` 2022-04-01 10:10:43 -07:00
Ramiro Leal-Cavazos 51d4d55f8a
Add support for multi-dim input to `index_put_impl` (#722)
This commit adds support for multi-dimensional tensors as input to the
`_index_put_impl_` op. The support was to some degree already there,
since `ScatterOp` already supports multi-dimensional tensors. This
commit also adds a bit more error checking to `index_put` and
refactors the code for creating `ScatterOp`s to mimic the way one
would make a `Linalg::GenericOp`.
2022-03-31 09:27:21 -07:00
Anup Gangwar ccf924d3df
tosa] Support for Aten[Gelu|GeluBackward] ops (#720)
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
2022-03-30 17:00:55 -07:00
Sean Silva c17c0a6ba2 Fix for 0-size dim inferred incorrectly.
The issue was in the canonicalizer for torch.aten.ge.int -- in cases
where the operands were swapped, it would miscompile. This issue is
fixed and folding support generalized to `torch.aten.size.int < 0` as
well.

Fixes #716
2022-03-30 16:36:15 -07:00
Sean Silva 8250f50c81 Attempt to set Python package version to the snapshot identifier.
This should make the releases sort properly when `pip`'s
`-f`/`--find-links` argument is used.
2022-03-30 17:54:11 +00:00