Commit Graph

896 Commits (f77d88390a8e3c4bdc2172a0f0342d0df21c598d)

Author SHA1 Message Date
saienduri a2e694df40
add e2e support for torch.eye operations (aten.eye, aten.eye.m) (#2478) 2023-11-01 11:23:28 -07:00
Daniel Garvey 1d41f7b6fe
Rework AtenEmptyStridedOp checks (#2537)
Now using Value instead of Ints. Trades compile failure for a runtime
assert
2023-10-31 22:56:54 -05:00
xiaolou86 4199feffed
Fix typos in comments (#2539)
Fix typos in comments
2023-10-31 20:10:47 -07:00
JianzheXiao e8706957c0
[Torch Dialect] Add Support for aten.unflatten.int (#2475)
As title, Add support for aten.unflatten.int, support dim to be negative
and one of the sizes' elements to be -1
2023-10-31 15:36:16 +08:00
Yuanqiang Liu e7282487ea
[Torch Dialect] support aten.glu (#2531) 2023-10-26 10:36:18 +08:00
Sarthak Gupta 7633619ed2
[torch] Implement stronger verifiers for non-value semantic ops (#2519)
Attempt to solve https://github.com/llvm/torch-mlir/issues/2490

Changes for Non Value Semantic Ops having the
`IsTrailingUnderscoreInplaceVariant` trait :
- AnyTorchTensorType -> Torch_NonValueTensorType
- AnyTorchOptionalTensorType -> AnyTorchOptionalNonValueTensorType
- AnyTorchListOfOptionalTensorType ->
AnyTorchListOfOptionalNonValueTensorType
- AnyTorchListOfTensorType -> AnyTorchListOfNonValueTensorType

Created three new tensor types for optional and list non value tensors.
2023-10-21 09:09:55 -07:00
Ze Zhang f2c53b8ca5
Add aten.isclose support and its torch-to-tosa lowering (#2512)
Add aten.isclose op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests


To test e2e tosa lowering:
`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2023-10-16 09:44:53 -07:00
Ze Zhang e649e06b7b
Add aten.unflatten.int support and its torch-to-tosa lowering (#2509)
Add aten.unflatten.int op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests

To test e2e tosa lowering:

`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2023-10-13 18:39:41 -07:00
Ramiro Leal-Cavazos 2e5d65064c [linalg] Add handling for leadin and trailing size-1 dims in ViewOp
This commit adds to the lowering of `aten.view` handling for the
following cases:

- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`
- `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)`
- `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)`
2023-10-03 23:04:52 +00:00
Ramiro Leal-Cavazos 1c508af0ba Revert "[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)"
This reverts commit 7c6b9d2445.
2023-10-03 23:04:52 +00:00
Vivek Khandelwal ca6ce8974f [MLIR][TORCH] Add support for int8 dtype for sub, add, and bitwise_and op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-10-03 22:12:31 +05:30
Jae Hoon (Antonio) Kim 32d9b20bde
Add linspace/cumprod/roll ops (#2498)
Add linspace/cumprod/roll ops to ODS and add shape inference functions
to make it work with LTC.

Also, add some tensor utils to LTC library for searching for non-detach
copy nodes.
2023-10-03 11:01:07 -04:00
Vivek Khandelwal 9293326e1e [MLIR][TORCH] Add support for bitwise_right_shit and bitwise_and.Scalar op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-10-02 13:06:59 +05:30
Vivek Khandelwal c434736ee9 [MLIR][TORCH] Add support for conversion to int8 dtype
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-10-02 09:48:46 +05:30
Vivek Khandelwal 71ac62f3a8 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-09-28.

aten.baddbmm changes done because upstream PyTorch has now added
support for fp16 gemm on CPU.
Refer: 9399e0b1ff
2023-10-02 09:48:32 +05:30
saienduri 4e1dd3bf10
add e2e support for torch.log10 (#2479) 2023-09-28 10:17:03 -07:00
Ramiro Leal-Cavazos 7c6b9d2445
[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)
This commit adds to the lowering of `aten.view` handling for the
following cases:

- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`

Fixes: https://github.com/llvm/torch-mlir/issues/2448
2023-09-27 09:09:30 -07:00
Vivek Khandelwal 7760bda8ee build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-09-26.

aten._convolution.deprecated changes done because upstream PyTorch has
now added support for fp16 native convolution on CPU.
Refer: 7c9052165a

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-27 16:24:58 +05:30
Bruce Kim a520d39f84
[MLIR][TORCH] Add device "cpu" support for aten.to.dtype_layout op (#2481)
This PR adds device="cpu" support for `aten.to_dtypeLayout` op and
corresponding e2e test suit.
(refer:  PR https://github.com/llvm/torch-mlir/pull/812/)
2023-09-25 10:00:19 -04:00
Gleb Kazantaev 059041e0fe
[LTC] Support torch.ones/zeros/arange ops (#2440) 2023-09-21 13:25:14 -04:00
David Gens 023fc90072
[Torch Dialect] add avg_pool 2d and 3d op variants (#2473)
Adds ODS for `avg_pool2d` and `avg_pool3d`, including their backward and
`adaptive_` variants.
2023-09-20 13:47:08 -04:00
Bruce Kim 40913a36c2
[MLIR][TORCH] Add E2E support for aten.empty_strided decomposition op (redo PR) (#2459)
Making the same PR with #2457, as I accidentally thought the review was already made and merged it (reverted).

Add decompose empty_strided op.
Referring to #1776, this decomposition op only supports default stride values, because accessing the tensor or indexing over that, the indices are determined by the strides.
In MLIR, this is not implicitly supported but assumes that the strides are default while iterating over the tensor.
2023-09-13 10:04:31 -07:00
Vivek Khandelwal 4b4c38da46 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-09-13.
Ref: 464f9c3725

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-13 21:25:21 +05:30
Stella Laurenzo 078d1e1a1d
Remove mlir-hlo (replace with stablehlo). (#2460)
We just have to do this: I ran into an issue today where I needed to make a one line patch to stablehlo to work around a compiler issue, and it is completely unapparent how to do so given that the mlir-hlo repo is a read-only export and is at the tail end of a multi-week integration chain from the open-source stablehlo repo.

We've discussed this often enough and gotten +1 from everyone that they are ok with taking the e2e testing hit if it becomes necessary: It is necessary as the current situation is unmanageable.

Looking at it, I expect it wouldn't actually be very difficult to build a little runner binary out of the stablehlo interpreter and subprocess call that in order to get the testing coverage back. I leave that as an exercise to the users of this part of the stack and recommend following the breadcrumbs from the deleted python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py file and the main.py changes.

Note that I am pointing us at a stablehlo fork for the moment until it is apparent that we don't need to carry any local patches to it. We can update this in a few days if everything is clear.
2023-09-12 19:10:02 -07:00
Ramiro Leal-Cavazos 106b58597a
Revert "[MLIR][TORCH] Add E2E support for aten.empty_strided decomposition op (#2457)" (#2458)
This reverts commit 97bec86a8b.
2023-09-12 13:57:47 -07:00
Bruce Kim 97bec86a8b
[MLIR][TORCH] Add E2E support for aten.empty_strided decomposition op (#2457)
* implemented e2e test case, shape, dtype func

* AtenEmptyStrided decompose op implemented

* xfailed test module in ltc
2023-09-12 13:37:02 -07:00
Arham Khan 82456eefed
[MLIR][TORCH] add E2E support for aten.new_full (#2425)
* implement aten.new_full

* remove extraneous tests
2023-09-12 09:29:08 -05:00
Vivek Khandelwal 23b72244b1 [MLIR][TORCH] Add different dtype support for aten.bmm op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-12 12:38:46 +05:30
Yuanqiang Liu 1f20b7275d
[Torch Dialect] add canonicalize for aten.min.other (#2452) 2023-09-11 17:28:22 +08:00
Bruce Kim 27b55b1d5f
implemented complex tensor aten mul (#2444) 2023-09-07 13:29:15 -07:00
Jiawei Wu b411a40b3d
[Torch Dialect] emit aten.__or__Tensor Op (#2437)
* emit aten.__or__TensorOp

* bug fix

* remove convert to stablehlo

* code style refinement
2023-09-06 14:21:51 +08:00
Stella Laurenzo fcb3b718a5 Properly guard clang-specific pragma.
Avoids unsupported pragma warning on GCC.
2023-09-06 00:43:50 -07:00
Jerin Philip 9cb5d38cd1
[MLIR][TORCH] Add E2E `torch.aten.prod_dim_int` (#2423)
Uses the existing reduction codepath, adding modifications or branches
required alongside for prod.
2023-09-05 13:38:51 -07:00
Vivek Khandelwal 3841fe3035 [MLIR][TORCH] Add StableHLO lowering for embedding_bag.padding_idx op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-05 21:32:23 +05:30
Jiawei Wu d62045f64d
emit aten.max.other op (#2436) 2023-09-05 10:52:32 +08:00
Yuanqiang Liu e9ab8ceb1c
[Torch Dialect] support aten.split_with_sizes (#2431)
* [Torch Dialect] support aten.split_with_sizes

* update
2023-09-04 09:59:26 +08:00
Bruce Kim cd1c7df8be
[MLIR][TORCH] Add E2E support for view_as_real op (#2419)
* view_as_real test case, allow dtype in testutils.randn

* abstract python upstream func implemented

* fixed upstream dtype func, implemented view_as_real backend op

* formatted AtenViewAsRealOp, removed change in e2etest/framework

* removed test suit from reshape_like.py, because it's moved to basic.py

* implemented C-API wrapper for mlirComplexF128 type

* fixed torch.complex dtype width in MLIR and Torch MLIR, deleted float16 dtype dict

* Changed IR input of aten fft_fft unit test

* code refactored

* code refactored and fixed ci test

* refactored: removed white spaces, and rolled back to having both input/output affine expr

* refactored: deleted output affine expr to reduce redundancy

* xfail ltc backend

* removed ComplexImag and ComplexReal from torchdynamo xfail set

* copied and pasted from main branch as there's no change to be made in this file

* refactored abstract_interp_lib_gen.py

* refactored: torchtypes.td, formatted, removed commented out code
2023-09-01 21:12:01 -07:00
Quinn Dawkins 1fc4314b62
Add folder for aten.broadcast_to on unchanged static shapes (#2421) 2023-09-01 14:50:34 -04:00
Vivek Khandelwal 729386c9d8 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-09-01.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-01 22:07:51 +05:30
Vivek Khandelwal 5c43daa3bf [MLIR][TORCH] Add e2e support for aten.pow.Scalar op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-31 21:43:24 +05:30
Vivek Khandelwal aa15f0d4ca build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-08-30.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-31 16:23:34 +05:30
Gleb Kazantaev 6b02e9a926
[LTC] Tensor[]? support operands type support using partial codegen (#2410)
* Tensor[]? support operands type support using partial codegen

* aten.index.Tensor support via partial codegen

* Add torch.index_put tracing support

* Added optional tensor list type support for LTC/TorchMLIR lowering

* Added comments

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
2023-08-30 06:29:39 -04:00
JianzheXiao 17d02811d5
[Torch Dialect] add folder for aten.any.bool (#2388)
* update

* update

* update

* update

* update

* update

* update
2023-08-30 17:29:03 +08:00
Arham Khan c42d2beb6e
[MLIR][TORCH] add E2E support for aten.min op (#2422)
* impl aten.min op

* remove extraneous test
2023-08-29 12:12:41 -05:00
Zhekun(Josh) Zhang 5282324c68
[Importer] fix has value semantic return type (#2404)
* fix value semantic return

* address comments

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
2023-08-29 10:14:09 +08:00
David Gens ca34b9c4fc
add max_pool3d (#2386) 2023-08-28 19:01:55 -04:00
Arham Khan bc6bba9077 add nondefault test case, add to illegal ops in backend contract 2023-08-28 10:52:16 +05:30
Arham Khan 8855fa3ace amend dtype function 2023-08-28 10:52:16 +05:30
Arham Khan a80bc42521 dtype test case 2023-08-28 10:52:16 +05:30
Arham Khan 610d836fd2 impl aten.elu as decomposition 2023-08-28 10:52:16 +05:30
Arham Khan 12eadccc07 add e2e support for aten.elu 2023-08-28 10:52:16 +05:30
Jiawei Wu 4339c00f1b
[Torch Dialect][stablehlo] emit aten.rand op and add converter to stablehlo (#2413)
* [Torch Dialect] emit aten.rand op and add converter to stablehlo

* add failed tests for torchdynamo backend

* add failed test for linalg backend
2023-08-27 21:56:36 +08:00
Gleb Kazantaev 3dd29f9d5d
Update Torch ODS list with new ops (#2361)
* [LTC] Add shape_inference_(add|uniform)

* Add torch.multinomial op.

* Update ods gen; add normal_functional and erfinv ops support

* New TorchMLIR ops: clamp_min.Tensor, clamp_max.Tensor, xlogy, binary_cross_entropy, log_sigmoid_forward, sigmoid_backward, cosine_embedding_loss, scatter.reduce

* Improve the shape inference logic of whereOp

- Infer the result tensor according to the broadcasting semantics

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>

* Added aten::sgn

* Add shape inference logic for hardtanh_backward op

* Added new Torch-MLIR ops

Co-authored-by: GlebKazantaev <gleb.nnstu@gmail.com>

* Add support for elu lowering

* Add support for elu_backward lowering

* Support fmod, remainder, and floor_divide

Emit generated op defs for the remainder.Tensor and fmod.Tensor

Add shape inference impelementations for remainder.Scalar, fmod.Scalar
and floor_divide.Tensor

* Add shape inference logic for im2col

- pytorch.nn.unfold gets decomposed into im2col

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>

* Add aten::eye and aten::eye.m support

* Add tracing for linalg_qr

* Update GeneratedTorchOps.td

* Update xfails

* Fix unbound variable issue in torch_ods_gen

---------

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: zihaoc-cerebras <zihao.chen@cerebras.net>
Co-authored-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
Co-authored-by: Gokul Ramakrishnan <gokul.ramakrishnan@cerebras.net>
Co-authored-by: glebk-cerebras <111300564+glebk-cerebras@users.noreply.github.com>
Co-authored-by: Behzad Abghari <behzad.abghari@gmail.com>
Co-authored-by: Ahmed Elkoushy <ahmed.elkoushy@cerebras.net>
2023-08-21 06:36:39 -04:00
Gleb Kazantaev 5743b6d4ac
LTC multi-output operations support (#2362)
* LTC/TorchMLIR multi-output operations support

* Update torch-mlir jit lowering to support ops with dynamic number of outputs

* Added support for aten::split_copy, aten::split_with_sizes_copy

* Fix native function for aten::split; cleanup code

* Fix TorchMlirTensorList lowering

* Remove xfails
2023-08-20 16:32:11 -04:00
Simon Camphausen d77b9cf7ae
[TOSA] Fix conversion for depthwise convolutions (#2398)
* [TOSA] Fix conversion for depthwise convolutions

* Add e2e tests for depthwise and grouped convolutions

Co-authored-by: Lucas Camphausen <lucas.camphausen@iml.fraunhofer.de>
2023-08-18 08:15:54 -07:00
Jiawei Wu 60bad54f27
[Torch Dialect] replace none-index in aten.Index.Tensor's param by manually generating it (#2344)
* [Torch Dialect] replace none-index in aten.Index.Tensor's  param by manually generating it
Co-authored-by: Jiawei Wu <wujiawei.aml@bytedance.com>
Co-authored-by: Jianzhe Xiao <jianzhe.xiao@bytedance.com>

* minor typo fix

* add new failed e2e tests for ltc

* fix typo

* Address comments

* Add more e2e tests

* add failed e2e tests for LTC

* address comments

* remove decomposition for AtenIndexTensorHackedTwinOp
2023-08-15 19:36:08 +08:00
Ramiro Leal-Cavazos ff762100b8
Add handling of namespaces to library generator (#2391)
When using custom ops, sometimes PyTorch will insert namespaces to the
abstract interpretation function name in the format:
`__torch__.{namespace_1}.{namespace_2}...{op_name}`.  The extra
namespaces are not part of the abstract interpretation function name,
so it needs to be removed before generating the library of MLIR
snippets of abstract interpretation functions. This commit adds
support for removing the namespace information.
2023-08-11 09:56:19 -07:00
Vivek Khandelwal e61ef1ee54 [MLIR][TORCH] Add support for aten._unsafe_index_put.hacked_twin op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-11 08:57:01 +05:30
Vivek Khandelwal f0a8f273f7 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-08-10.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-10 21:59:20 +05:30
Vivek Khandelwal ee6c87ef5b [MLIR][TORCH] Add support for dtype arg for softmax.int op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-08 21:54:47 +05:30
JianzheXiao 38b049eb1a
[Torch Dialect] add support for adaptive_avgpool_1d (#2342)
* [MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.

update PyTorch version to 2.1.0.dev20230729 (#2354)

- torch version: 2.1.0.dev20230729
 - torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
 - torchvision version: 0.16.0.dev20230729

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

update PyTorch version to 2.1.0.dev20230730 (#2356)

- torch version: 2.1.0.dev20230730
 - torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
 - torchvision version: 0.16.0.dev20230730

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

update PyTorch version to 2.1.0.dev20230731 (#2359)

- torch version: 2.1.0.dev20230731
 - torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
 - torchvision version: 0.16.0.dev20230731

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

LTC->MLIR Debug Info support (#1922)

* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>

build: update llvm tag to 41895843

Summary of changes:
- Update tags
  llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
  mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

update PyTorch version to 2.1.0.dev20230802 (#2366)

- torch version: 2.1.0.dev20230802
 - torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
 - torchvision version: 0.16.0.dev20230802

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

Change Python version from 3.10 to 3.11 in installation instructions (#2370)

Add CITATION file (#2371)

Add packaging as an install dependency (#2369)

Needed by `torch_mlir._version`. Resolves #2368.

[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)

* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op

update PyTorch version to 2.1.0.dev20230803 (#2372)

- torch version: 2.1.0.dev20230803
 - torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
 - torchvision version: 0.16.0.dev20230803

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

Prevent failed stable CI job from cancelling nightly jobs (#2373)

The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.

[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)

update

update xfail sets

update xfail_sets

update

fix xfail_sets

update:

update

update:

update

parent 22e88d523b1970b2e904eb5421d49d987a3d255e
author jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114110 +0800
committer jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114119 +0800

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.

update PyTorch version to 2.1.0.dev20230729 (#2354)

- torch version: 2.1.0.dev20230729
 - torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
 - torchvision version: 0.16.0.dev20230729

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

update PyTorch version to 2.1.0.dev20230730 (#2356)

- torch version: 2.1.0.dev20230730
 - torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
 - torchvision version: 0.16.0.dev20230730

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

update PyTorch version to 2.1.0.dev20230731 (#2359)

- torch version: 2.1.0.dev20230731
 - torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
 - torchvision version: 0.16.0.dev20230731

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

LTC->MLIR Debug Info support (#1922)

* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>

build: update llvm tag to 41895843

Summary of changes:
- Update tags
  llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
  mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

update PyTorch version to 2.1.0.dev20230802 (#2366)

- torch version: 2.1.0.dev20230802
 - torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
 - torchvision version: 0.16.0.dev20230802

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

Change Python version from 3.10 to 3.11 in installation instructions (#2370)

Add CITATION file (#2371)

Add packaging as an install dependency (#2369)

Needed by `torch_mlir._version`. Resolves #2368.

[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)

* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op

update PyTorch version to 2.1.0.dev20230803 (#2372)

- torch version: 2.1.0.dev20230803
 - torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
 - torchvision version: 0.16.0.dev20230803

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>

Prevent failed stable CI job from cancelling nightly jobs (#2373)

The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.

[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)

update

update xfail sets

update xfail_sets

update

fix xfail_sets

update:

update

update:

add support for adaptive_pool_id

update xfail sets

update xfail_sets

update

fix xfail_sets

update:

update:

* update

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-08-05 07:48:09 +08:00
Jiawei Wu 20a2b68ed6
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355) 2023-08-04 09:05:34 +08:00
Jiawei Wu 6db92d1b14
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
2023-08-03 16:21:14 +08:00
Vivek Khandelwal a374c39106 build: update llvm tag to 41895843
Summary of changes:
- Update tags
  llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
  mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-08-02 21:18:14 +05:30
Gleb Kazantaev fb52a73cbe
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
2023-08-02 10:29:11 -04:00
Vivek Khandelwal 0109bf705b
[MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-07-28 09:45:12 -07:00
JianzheXiao 31ef08b63d
[Stablehlo]Add support for AvgPool1dOp (#2268)
* Add support for AvgPool1d

* Update AbstractInterpLibrary

* support avgpool1d in linalg

* refactored code

* fix nit problem
2023-07-25 14:09:53 +08:00
Jiawei Wu d57f67e7f8
[Torch Dialect] emit aten.nonzero, aten.nonzero_numpy, aten.nonzero_static op (#2338)
By the way, this PR also adds the missing shape function for aten.masked_select.
2023-07-25 09:01:19 +08:00
Ramiro Leal-Cavazos 4a96e716c0
Use `register_buffer` to make `Add_Module` test work on lazy tensor (#2332)
Doing `module.to('lazy')` only moves the module member tensors to the
device if they are created with `self.register_buffer` or
`self.register_parameter`. Since the `self.tensor` tensor in
`Add_Module` test is currently not created using the `self.register_*`
methods, it is not being moved from CPU to lazy device, which is
causing the test to fail on LTC backend. This commit uses
`self.register_buffer` to fix the test on LTC backend.

This commit also seems to fix the test for torchdynamo.
2023-07-24 09:07:13 -07:00
Alexandre Rames 1e468e8294 Fix canonicalization of `torch.prim.TupleUnpack`. 2023-07-20 20:08:46 +02:00
Jiawei Wu 9535be7903
[Torch-Dialect] emit aten.narrow.Tensor op and decompose it to aten.narrow op (#2297) 2023-07-20 16:46:44 +08:00
Matthias Gehre 64d7626a52
Fixes for split tensor and slice (#2314)
* RecomposeComplexOps: Remove dead slice op

* lib/Dialect/Torch/IR/TorchOps.cpp: Fold slice ops even when they are on non-value tensors

* lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix slice start/end out of range/none

* lib/Dialect/Torch/IR/TorchOps.cpp: AtenSliceTensorOp::fold: Fold slices that go from 0:int_max

* More tests for aten.split.Tensor
2023-07-20 09:53:54 +02:00
max 0650efe7c0 Conform to Python custom exception api 2023-07-19 21:00:55 -05:00
Jiawei Wu 3f843c8fd9
[torch-dialect] fix aten.type_as op's folder (#2283)
[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtype
2023-07-20 09:51:58 +08:00
Ramiro Leal-Cavazos 718f53ff8a
Fix handling of `!torch.number` in abstract interpretation library (#2309)
In PyTorch, the `NumberType` is equal to `Union[int, float,
complex]`. However, the abstract interpretation library was treating
the `NumberType` as `Union[int, float]`, resulting in type mismatches
when reifying certain dtype functions. This commit fixes the type
inconsistency by having the abstract interpretation functions take as
an input a `Union[int, float, complex]` for the ops that take
`!torch.number` inputs.
2023-07-17 09:52:04 -07:00
Chi_Liu 5706697e0b
[TOSA] Add aten._index_put_impl support (#2031)
Add e2e support by add  "tosa-to-scf"
2023-07-17 09:51:24 -07:00
Matthias Gehre 06c9bd08e0
lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix legalization of comparions where the input type is bool (#2304) 2023-07-17 09:49:04 +02:00
Matthias Gehre f8e75f659d
Add make_fx_tosa variant to end2end tests (#2240)
* Add make_fx_tosa variant to end2end tests

* e2e_testing/xfail_sets.py: Add make_fx_tosa xfail for stable
2023-07-13 15:07:54 +02:00
nithinsubbiah 91c6454618 Filter out empty strings while generting function signature 2023-07-13 13:51:54 +05:30
Abhishek Varma 6c9ba4ce95
[Torch-to-Linalg] Add dynamic dimension support for BroadcastTo op (#2174)
-- This commit adds support for dynamic dimension in BroadcastTo op.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-07 10:01:51 -07:00
Jiawei Wu c7fa42b7d3
[Torch Dialect] Add canonicalizer for aten.to.other op (#2273)
Canonicalize aten.to.other to prim.device + prim.dtype + aten.to.device
Co-authored-by: wujiawei.aml <wujiawei.aml@bytedance.com>
2023-06-30 09:43:08 +08:00
Yuanqiang Liu 449cfb8375
[Torch Dialect] add more scalar op folders (#2265) 2023-06-29 10:37:13 +08:00
Yuanqiang Liu 859885c1d3
[Torch Dialect] Support aten.native_dropout (#2259)
* [Torch Dialect] Support aten.native_dropout

* update
2023-06-27 14:19:33 +08:00
Yuanqiang Liu 1ea2b57ab7
[Torch Dialect] add folder for aten.add (#2264)
* [Torch Dialect] add folder for aten.add

* update

* update

* update
2023-06-27 10:55:28 +08:00
Yuanqiang Liu 64afc08dab
[Torch Dialect] add missing one_hot dtype function (#2143)
* [Torch Dialect] add missing one_hot dtype function

* update

* update

* update
2023-06-23 16:11:33 +08:00
Ramiro Leal-Cavazos 6f2bf31291
Fix single-element tuple construction in abstract interp library (#2258)
Single element tuples in Python need a comma after the
element. However, the `registry.py` file, which generates the expected
abstract interpretation function signatures, was not inserting the
comma. This commit changes the expected signature generator to add a
comma after the last element in any non-empty default tuple argument.
2023-06-22 11:27:40 -07:00
Yuanqiang Liu 96b14e952e
[Torch Dialect] Support aten.device.with_index (#2254) 2023-06-23 01:07:14 +08:00
Abhishek Varma a0d2789840 [MLIR][TORCH] Add e2e support for aten.alias
-- This commit adds e2e support for aten.alias op.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-21 12:15:31 +05:30
Vivek Khandelwal f6a6cfea4e
[MLIR][TORCH] Add support for negative index values for index.Tensor op (#2233)
This commit adds the support for index.Tensor op when the index values
are negative. This commit wraps around the index values by checking
their values at run time.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-06-16 14:21:04 -05:00
Vivek Khandelwal ab8b23e767 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-05-16.
This commit removes the test `BaddbmmDifferentDtypesModule_basic`
since PyTorch expects all operands to have the same dtype.
Ref: 2abad0c184

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-06-15 17:53:16 +05:30
Yuanqiang Liu bba0f5891b
[Stablehlo] add conversion for AtenFlipOp (#2163) 2023-06-15 10:27:34 +08:00
Yuanqiang Liu 7c6961bcbf
[Torch Dialect] Support aten.cuda and add canonicalizer for aten.cuda (#2231) 2023-06-14 09:56:39 +08:00
Maksim Levental 0caaf8d32a
Bump LLVM (#2176)
* Bump LLVM

---------

Co-authored-by: Matthias Gehre <matthias.gehre@xilinx.com>
2023-06-13 16:17:23 +02:00
Christopher McGirr b461daa06e
fix(TorchToTosa.cpp): adjust torch->tosa div conversion (#2200)
check the return type of the division to figure out whether to use
the floating point implementation of a division or to use the integer.

the issue rose from the fact that the inputs are all integer but the
result was casted to floating point. The conversion then chose to
use the integer implementation of division which is not legal in tosa
when all the inputs get casted to floating point.

fix(TorchToLinalg): AtenDivScalarOp

upcast self operand as well if applicable, the self operand must also
be casted to float as it can be an integer.
2023-06-12 11:18:38 +02:00
Tiago Trevisan Jost cc75557119
feat: support unchanged dimensions in torch.aten.broadcast_to operation. (#2204) 2023-06-12 11:17:25 +02:00
Matthias Gehre 4e2ba2e0af
Support aten.sign (#2205) 2023-06-10 20:45:35 +02:00
Matthias Gehre 0959b502ae
Print name of the backend when tests fail to help debugging issues in CI (#2210)
* Print name of the backend when tests fail to help debugging issues in CI

* Extended test python/test/torchscript_e2e_test/compilation_failure.py
2023-06-09 10:47:07 +02:00
Yuanqiang Liu 5a7bf4e4cb
[Torch Dialect] Add canonicalize pattern for aten.is_floating_point (#2194)
* [Torch Dialect] Add canonicalize pattern for aten.is_floating_point

* implement as fold

* add lit test
2023-06-07 17:05:31 +08:00
Matthias Gehre 816880774b
Fix version comparison against stable (#2209) 2023-06-07 10:19:38 +02:00
JianzheXiao e4f8fb1b8c
[Torch Dialect] add support for AtenIsnanOp (#2170)
* add support for mhlo

* Add Test for torch.ne

* fix torch.ne shape/add static test case

* add support for static torch.ne

---------

Co-authored-by: root <root@n31-177-039.byted.org>
2023-06-07 10:06:27 +08:00