Commit Graph

638 Commits (revert-2869-AtenSlice-folder)

Author SHA1 Message Date
Andreas Falkenberg 5862854bc8
[ONNX][TORCH-MLIR] LayerNorm (#2716)
Layer Normalization using the torch.aten.native_layer_norm 

https://github.com/nod-ai/SHARK-Turbine/issues/325
2024-01-11 14:27:04 +05:30
kumardeepakamd 29569713f3
support for onnx.expand operator (#2729)
maps onnx.expand to torch aten broadcast_to, three tests added

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
2024-01-10 13:05:37 -08:00
Vivek Khandelwal 208ae35583 [MLIR][ONNX] Add TorchToOnnx Support for DepthToSpace op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 17:50:47 +05:30
Vivek Khandelwal 4707d3bdc6 [MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 16:24:06 +05:30
Vivek Khandelwal 35e8f86792 [MLIR][ONNX] Add OnnxToTorch support for Dropout and Elu op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 16:23:55 +05:30
John Wu 4e5e34d215
[MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696) 2024-01-03 19:41:10 -08:00
Xida Ren (Cedar) 1778314620
add basic cumsum. this doesn't support the exclusive and reverse attrs (#2717)
fixes #2711
2024-01-03 09:52:59 -08:00
Xida Ren (Cedar) 9fc212ea9a
support Onnx opset 1-13 ReduceMean where axes is supplied as an attr (#2703)
(instead of an input)

Addresses part of #2689. fixes #2702
2023-12-28 09:31:41 -08:00
Xida Ren (Cedar) d560698e3d
Lower `onnx.split` to `torch.aten` (#2686) 2023-12-27 17:53:07 -08:00
aldesilv 2d796b7502
lower onnx max op to torch aten maximum op (#2618)
lower onnx min op to torch aten minimum op
2023-12-27 11:07:35 -08:00
aldesilv 336cfb64b5
OnnxToTorch support for onnx.Mul op (#2699) 2023-12-27 10:50:08 -08:00
Xida Ren (Cedar) 6847fc1fc6
Fix since-opset too high (#2701)
Addresses two of the ops from
https://github.com/llvm/torch-mlir/issues/2689

https://github.com/llvm/torch-mlir/issues/2700
2023-12-27 10:08:09 -08:00
aldesilv abc6b0a25a
onnx to torch pow support (#2656) 2023-12-27 09:34:48 -08:00
Vivek Khandelwal 4f252c88b4
[MLIR][ONNX] Add OnnxToTorch support for GlobalAveragePool op. (#2692)
This commit adds the OnnxToTorch support for GlobalAveragePool op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-26 10:25:31 -08:00
saienduri ee75e8d1ae
[MLIR][ONNX] Add OnnxToTorch support for Reshape Op (#2698)
This commit adds the OnnxToTorch support for Reshape op.
2023-12-26 10:20:13 -08:00
Vivek Khandelwal 0849fd0a06 [MLIR][ONNX] Fix onnx.conv lowering to handle bias tensor
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-12-22 16:36:21 +05:30
Vivek Khandelwal 9a72c6584e [MLIR][ONNX] Add OnnxToTorch support for BatchNormalization and Concat op.
This commit adds the OnnxToTorch support for BatchNormalization and Concat op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-22 11:25:33 +05:30
Stella Laurenzo ccd469ca0d
[fx] Upstream the turbine FxImporter to torch-mlir. (#2681)
Changes made during upstreaming:

* Removed comments attributing some copied code back to torch-mlir
(since it is now repatriated).
* Re-organized imports.
* Inlined RefMapping/RefTracker and TypeSubclassMap from an external
utility module.
* Added FxImporter class comments.
* Updated stack trace extraction to be fail safe.
* Added an entry-point for `import_frozen_exported_program` which uses
the shiny new upstream `torch.export.export()` API (versus the
lower-level/older API that Turbine is presently using). This
necessitated a small FX rewrite to line external state management up
with current conventions.
* Adapted one of Turbine's importer tests to go with this initial
submission. Turbine unfortunately has a lot of more-integration-ey
tests, and I would like to extract those as more of unit tests of the
importer features and upstream them that way vs trying to copy directly.
For now, one overall test with the initial submission gets us moving.

I acknowledge that there are some code quality things that could be
improved in this submission: this was authored over the course of many
months (and often via some trial and error). I would like to keep it
relatively converged with the downstream for the next few steps while
getting the test suite upstreamed. And then it will be easier to take a
hygienic pass through the code.

Including co-authors for contributors in the git log of the original
repository.

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Avinash Sharma <aviator1994@gmail.com>
Co-authored-by: Arham Khan <arhammkhan@gmail.com>
Co-authored-by: brucekimrokcmu <kwangkyk@alumni.cmu.edu>
Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
2023-12-21 08:40:10 -08:00
John Wu 46f2cb50dc
[onnx] Lower onnx.HardSigmoid to torch (#2682)
The expression for HardSigmoid in Onnx
(https://onnx.ai/onnx/operators/onnx__HardSigmoid.html): max(0, min(1,
alpha * x + beta))

is inherently different from HardSigmoid in Torch
(https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html)
which is: if x < -3 -> 0
elif x > 3 -> 1
else x/6 + 1/2

That being said, it was just better to compute out the entire expression
when translating the Onnx expression to Torch mlir, which is done in
this PR. Some of the logic is shared from the files in
`DecomposeComplexOps`. Therefore, refactored some shared logic between
`DecomposeComplexOps` and `DefaultDomainGToP` and put it in a `Utils`
file.
2023-12-21 07:29:22 -08:00
Vivek Khandelwal 3226241521 [MLIR][ONNX] Add OnnxToTorch support for Conv and ConvTranspose op.
This commit adds the OnnxToTorch support for Conv and ConvTranspose op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-21 11:12:14 +05:30
Rik Huijzer 8328998172
Allow printing all IR in `torch_mlir.compile` (#2669)
This PR adds the `enable_ir_printing` option to `torch_mlir.compile`,
which can be used to print the IR for all intermediate passes.

When running the added test file via:
```shell
$ python test/python/compile.py 2> tiny.stderr
```
the file `tiny.stderr` is about 700 KB.
2023-12-20 15:08:21 -06:00
Rob Suderman 11cc92d4ab
[onnx] Lowerings from `onnx.tan` (#2642)
Started work on the `tan` lowerings for ONNX to Torch. Uses `sin` and
`cos` to represent a `tan`.
2023-12-20 10:09:39 -08:00
Rob Suderman a24aadbfab
[aten] Make `torch.aten.matmul` to `linalg` work for non-broadcasting case (#2659)
Broadcasting for `torch.aten.matmul` is optional so a MxN with NxK
matmul should be legalized to a `linalg.matmul`.
2023-12-20 10:09:10 -08:00
Andreas Falkenberg ebaab4200f
[ONNX] ONNX -> TORCH for Erf (#2673)
TorchOnnxToTorch
For Erf function
2023-12-19 08:07:27 -08:00
Vivek Khandelwal 8649b84e3f
[MLIR][ONNX] Add OnnxToTorch support for AveragePool op. (#2672)
This commit adds the OnnxToTorch support for AveragePool op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-18 18:17:11 -06:00
saienduri 698ff3a736
[MLIR][ONNX] Add OnnxToTorch support for Reduction Ops (#2657)
This commit adds the OnnxToTorch support for ReduceSum, ReduceMean, and
ReduceMin ops.
2023-12-18 12:37:31 -08:00
John Wu deacb8ef38
[MLIR][ONNX] Add OnnxToTorch support for Gelu (#2647)
This commit adds the OnnxToTorch support for Gelu op.

---------

Co-authored-by: Rob Suderman <suderman@google.com>
2023-12-18 10:57:08 -08:00
Rob Suderman 791c666479
[torch] Lower `torch.aten.sinh` to `linalg` (#2662) 2023-12-18 09:15:12 -08:00
Rob Suderman ae1a6e4a5a
[onnx] Lower `onnx.Gemm` to `torch` (#2663)
General lowering for `onnx.Gemm` to `torch`
2023-12-16 10:47:58 -08:00
Andreas Falkenberg cee8563060
[onnx] Support of onnx.Greater, onnx.Less, onnx.GreaterOrEqual to Torch (#2649)
The three remaining compare operations
onnx.Greater 
onnx.Less 
onnx.GreaterOrEqual

Are also added with this push request. 
This concludes a set of basic tensor compare functions.
2023-12-16 12:42:11 -05:00
Rob Suderman 61888690bb
[onnx] Add support for `onnx.sinh` (#2643)
Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding
the `aten.sinh` operator.
2023-12-15 21:23:51 -08:00
Rob Suderman 705ea958ae
[onnx] Lowerings from `onnx.transpose` (#2641)
Lowerings for `transpose` from ONNX to `aten`. Implementation depends on
making multiple `aten.transpose` operations swapping pairs of dimensions.
As `onnx.transpose` can swap around any dimensions it may require
constructing multiple `aten.transpose`.
2023-12-15 15:30:05 -08:00
Quinn Dawkins 030b0140d4
[TorchToLinalg] Lower aten.cat to tensor.concat (#2650)
This replaces the lowering of aten.cat with tensor.concat, allowing more
efficient handling of concatenations in downstream flows. The refbackend
populates concat decomposition patterns that can be used to recover the
previous lowering.
2023-12-15 15:45:32 -05:00
Rob Suderman 061af696ce
[onnx] Lowering for `onnx.shape` to `torch` and `tensor` (#2648)
Includes the lowering from the `aten` equivalent to `tensor` operations.
2023-12-15 11:37:49 -08:00
Gaurav Shukla eb9249e601
[ONNX][MLIR] Add support for LeakyRelu and GatherElements op (#2655)
This commit adds support for `LeakyRelu and GatherElements` op in the
onnx pipeline.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-12-15 11:18:28 -08:00
saienduri f59c01fd2f
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-15 09:36:18 -08:00
Andreas Falkenberg 4ec8b9fc02
[onnx] add support for onnx.LessOrEqual (#2639)
Added the less or equal operation to OnnxToTorch. 
onnx.LessOrEqual

---------

Co-authored-by: root <andreas.falkenberg@amd.com>
2023-12-14 22:23:23 -05:00
Rob Suderman 4857606ffe
[onnx] Lowerings from `onnx.selu` (#2634)
Lowerings for `selu` lowerings for ONNX to the corresponding torch
implementations. Torch's `selu` implementation has fewer features so
we use the a generalized `elu` with the input scale set to `1.0`.
2023-12-14 08:53:47 -08:00
John Wu 42392bc845
[MLIR][ONNX] Add OnnxToTorch support for matmul ops (#2629)
This commit adds the OnnxToTorch support for Matmul.
2023-12-13 09:35:32 -08:00
Stella Laurenzo ed4df38e8d
[onnx] Add torch-mlir-import-onnx tool. (#2637)
Simple Python console script to import an ONNX protobuf to the torch
dialect for additional processing.

For installed wheels, this can be used with something like:

```
torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx
```

Or from a dev setup:

```
python -m torch_mlir.tools.import_onnx ...
```
2023-12-12 22:01:30 -08:00
Stella Laurenzo 74f7a0c9d6
Upstream the ONNX importer. (#2636)
This is part 1 of 2, which will also include upstreaming the FX
importer. I started with ONNX because it forces some project layout
updates and is more self contained/easier as a first step.

Deviating somewhat from the RFCs on project layout, I made the following
decisions:

* Locating the `onnx_importer.py` into `torch_mlir.extras` as Maks
already has opened up that namespace and it seemed to fit. Better to
have fewer things at that level.
* Setup the build so that the root project only contains MLIR Python and
pure Python deps (like the importers), but this can be augmented with
the `projects/` adding more depending on which features are enabled.
* The default build continues to build everything whereas in
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode, it builds a
`torch-mlir-core` wheel with the pure contents only.

`onnx_importer.py` and `importer_smoke_test.py` are almost verbatim
copies from SHARK-Turbine. I made some minor local alterations to adapt
to paths and generalize the way they interact with the outer project. I
expect I can copy these back to Turbine verbatim from here. I also
updated the license boilerplate (they have the same license but slightly
different project norms for the headers) but retained the correct
copyright.

Other updates:

* Added the ONNX importer unit test (which also can generate test data)
in lit, conditioned on the availability of the Python `onnx` package. In
a followup once I know everything is stable, I'll add another env var
that the CI can set to always enable this so we know conclusively if
tests pass.
* Moved the ONNX conversion readme to `docs/`.
* Renamed CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` ->
`TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted the sense. Made the
JitIR importer and LTC options `cmake_dependent_options` for robustness.
2023-12-12 19:02:51 -08:00
Frederik Harwath b656c674ee Implement e2e support for aten.acos op
This depends on a change in the LLVM core repository which adds acos
support to the MLIR Math dialect.
2023-12-12 10:52:02 +01:00
Vivek Khandelwal 0b4422a253 [MLIR][ONNX] Add OnnxToTorch support for bitwise and math ops
This commit adds the OnnxToTorch support for BitwiseXor, BitwiseOr, Div, Equal, Cast,
Ceil, Floor, Cos, and Clip op.
This commit also adds the TorchToLinalg support for aten.clamp.Tensor and aten.clamp_min.Tensor op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-11 19:36:01 +05:30
Quinn Dawkins 141202bc01
[TorchToLinalg] Fix integer type handling for aten.mm (#2615)
Despite aten.mm requiring the input and output types match, we still opt
to maintain signedness semantics in case later passes try to do any sort
of integer type narrowing.
2023-12-07 00:13:53 -05:00
Sambhav Jain 44f6942796
Bump LLVM and StableHLO (#2598)
Bump LLVM to `5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372` and StableHLO to
`83f095e7217c897f1eccac5652600ceb944cb0e0`.

Bazel GHA:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7027647674
2023-11-28 22:12:24 -08:00
Vivek Khandelwal dc9ea08db5 [MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops
This commit adds the OnnxToTorch support for Atan, Bitshift, BitwiseAnd,
and BitwiseNot op.
This commit also adds the TorchToLinalg support for AtenBitwiseLeftShiftTensorOp.

Signed-Off By: vivekkhandelwal@nod-labs.com
2023-11-28 17:19:07 +05:30
Stella Laurenzo e06efc5136
Initial TorchOnnxToTorch conversion pipeline. (#2585)
Adds a pipeline to convert custom ops and metadata represented as
`torch.operator` custom ops to corresponding `torch` ops where possible.

This is part of a multi-part approach for building ONNX import in as a
regular feature of torch-mlir. It is focused on the conversions vs the
infra. We will end up maintaining a [pure-python
importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
to go with this in torch-mlir, and we will also maintain test case
generation utilities derived from it.

I have left substantial documentation in the README of the conversion
directory, including the recommended approach that we will take to keep
building this out.

(note that this organizes the code to coincide with the refactoring in
#2442 versus the current flat arrangement)
2023-11-21 21:02:55 -08:00
Zhekun(Josh) Zhang d67afa9e95
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543) 2023-11-21 13:26:17 +08:00
James Newling 647f2f5076
Additional tests for view lowering (#2584)
The logic for lowering the aten view op to linalg is fairly complex. 
In this PR I have tried to follow all non-failing paths through the 
lowering and add unit tests where they're missing.

There is 1 logical change to the lowering: redundant tensor.cast ops
(same source and destination type) are folded.
2023-11-20 17:35:25 -08:00
Stella Laurenzo 5eae0adff1
Breakup python pytorch deps (#2582)
This lifts the core of the jit_ir_importer and ltc out of the pt1
project, making them peers to it. As a side-effect of this layering, now
the "MLIR bits" (dialects, etc) are not commingled with the various
parts of the pt1 project, allowing pt1 and ltc to overlay cleanly onto a
more fundamental "just MLIR" Python core. Prior to this, the Python
namespace was polluted to the point that this could not happen.

That "just MLIR" Python core will be introduced in a followup, which
will create the space to upstream the FX and ONNX pure Python importers.

This primary non-NFC change to the API is:

* `torch_mlir.dialects.torch.importer.jit_ir` ->
`torch_mlir.jit_ir_importer`.

The rest is source code layering so that we can make the pt1 project
optional without losing the other features.

Progress on #2546.
2023-11-19 12:10:19 -08:00
James Newling dad1f012f6
Add verification for torch permute op (#2551)
- adds support for an optional verifier to the generated torch op
tablegen (GeneratedTorchOps.td)
- uses the above to add a verifier for the torch permute op. 

Motivation: I hit an unclear error from linalg while developing a
decomposition pass for pixel_shuffle. The error would have been clearer
if the problem had been detected earlier in the invalid aten.permute op.

Testing: new tests added. To run added tests, from the base directory
run

```
 ./build/bin/llvm-lit  test/Dialect/Torch/invalid.mlir
 ```
2023-11-15 11:47:54 -08:00
Yuanqiang Liu 3ab790c50a
[Torch Dialect] add canonicalize for aten.numel (#2562) 2023-11-11 12:16:53 +08:00
Stella Laurenzo 6961f0a247
Re-organize project structure to separate PyTorch dependencies from core project. (#2542)
This is a first step towards the structure we discussed here:
https://gist.github.com/stellaraccident/931b068aaf7fa56f34069426740ebf20

There are two primary goals:

1. Separate the core project (C++ dialects and conversions) from the
hard PyTorch dependencies. We move all such things into projects/pt1 as
a starting point since they are presently entangled with PT1-era APIs.
Additional work can be done to disentangle components from that
(specifically LTC is identified as likely ultimately living in a
`projects/ltc`).
2. Create space for native PyTorch2 Dynamo-based infra to be upstreamed
without needing to co-exist with the original TorchScript path.

Very little changes in this path with respect to build layering or
options. These can be updated in a followup without commingling
directory structure changes.

This also takes steps toward a couple of other layering enhancements:

* Removes the llvm-external-projects/torch-mlir-dialects sub-project,
collapsing it into the main tree.
* Audits and fixes up the core C++ build to account for issues found
while moving things. This is just an opportunistic pass through but
roughly ~halves the number of build actions for the project from the
high 4000's to the low 2000's.

It deviates from the discussed plan by having a `projects/` tree instead
of `compat/`. As I was thinking about it, this will better accommodate
the follow-on code movement.

Once things are roughly in place and the CI passing, followups will
focus on more in-situ fixes and cleanups.
2023-11-02 19:45:55 -07:00
Zhekun(Josh) Zhang 88d4c475d3
[Torch] Fix mixP case for non value semantic ops (#2540)
NonValueSemantic Ops like Add_, div_, etc. expect result DType to be the
same as the first input. However, current implementation would result in
wrong result type for case like:

```python
a = torch.randn(3, 3).half() # float16
b = torch.randn(3, 3) # float32
a += b # i.e. torch.ops.aten.add_(a, b)
```
torch expects `a` to be float16, but dtype refinement would infer
float32 type, since it's replaced by `aten.add`.
2023-11-02 12:40:08 +08:00
Daniel Garvey 4901773f77
add uncovered cases in view lowering (#2524)
removes unecessary checks from empty strided
2023-11-01 21:56:44 -05:00
Ze Zhang 4279b750da
update AtenClampOp in torch-to-tosa to handle fp inputs (#2516)
As titled.

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2023-10-17 14:49:47 -07:00
Chi_Liu 14a4da923b
Update llvm-project to b44b3494f60296db6aca38a14cab061d9b747a0a (#2511)
The main purpose is to bring in the new mesh dialect change.
https://github.com/llvm/llvm-project/pull/68007
2023-10-16 19:29:48 -07:00
Ze Zhang f2c53b8ca5
Add aten.isclose support and its torch-to-tosa lowering (#2512)
Add aten.isclose op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests


To test e2e tosa lowering:
`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2023-10-16 09:44:53 -07:00
Ze Zhang e649e06b7b
Add aten.unflatten.int support and its torch-to-tosa lowering (#2509)
Add aten.unflatten.int op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests

To test e2e tosa lowering:

`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2023-10-13 18:39:41 -07:00
Quinn Dawkins 6f81ad7293
[TorchToLinalg] Improve broadcast lowerings in strict symbolic modes (#2505)
With strict symbolic shapes, we can assume numpy-style dynamic
broadcasts never occur. This improves the lowering in the presence of
this assumption.
2023-10-05 15:15:26 -04:00
Quinn Dawkins ae72eec224
Improve aten.broadcast_to folder when in strict symbol mode (#2504)
Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts
never occur. This allows us to strengthen the folder for broadcasts to
cases where the rank is the same and all shapes match (including dynamic
sentinel values).
2023-10-05 09:02:10 -04:00
Stella Laurenzo 860be09a39
Elide dynamic broadcast checks when in strict symbolic shapes mode. (#2496)
When importing dynamic shaped programs from Dynamo, via torch.compile or
torch.export, we can assume that strict symbolic shape checks have been
done prior to generating torch IR. Among other shape checking, this
eliminates the case where an unknown dimension can be dynamically '1' in
a way that signals a broadcast.

Adds a `isAssumingStrictSymbolicShapes` utility which consults a
`torch.assume_strict_symbolic_shapes` attribute on an enclosing scope
and returns true if present.

In the linalg pipeline, many runtime checks are elided when this returns
true.
2023-09-29 16:45:48 -07:00
Stella Laurenzo a00a0d4bfb
Integrate llvm-project and mlir-hlo. (#2454)
Corresponding commits:

* mlir-hlo: 16886a108eff5197f816ca0f1950cc5ff1b078d9
* stablehlo: 77a59815a82b34f7b08ed2d42a711d9920682d0e
* llvm-project: 4acc3ffbb0af5631bc7916aeff3570f448899647

* Adapt to ByteCodeOpInterface changes.
* Adapt to RegionBranchPoint changes: https://reviews.llvm.org/D159116
* Adapt inferReturnTypes to get the value from properties.
* Adapt invalid.mlir to properties syntax
* [TOSA] Align with custom assembly format change.
* [TOSA] handle change of axis to int32 type
* [TOSA] Restore improper convert to i32

Landing with Windows broken (it cannot be fixed because of the way the mlir-hlo dep is inserted). Will followup with an untangling.
---------

Co-authored-by: TatWai Chong <tatwai.chong@arm.com>
Co-authored-by: Eric Kunze <eric.kunze@arm.com>
2023-09-12 15:09:57 -07:00
Bruce Kim cd1c7df8be
[MLIR][TORCH] Add E2E support for view_as_real op (#2419)
* view_as_real test case, allow dtype in testutils.randn

* abstract python upstream func implemented

* fixed upstream dtype func, implemented view_as_real backend op

* formatted AtenViewAsRealOp, removed change in e2etest/framework

* removed test suit from reshape_like.py, because it's moved to basic.py

* implemented C-API wrapper for mlirComplexF128 type

* fixed torch.complex dtype width in MLIR and Torch MLIR, deleted float16 dtype dict

* Changed IR input of aten fft_fft unit test

* code refactored

* code refactored and fixed ci test

* refactored: removed white spaces, and rolled back to having both input/output affine expr

* refactored: deleted output affine expr to reduce redundancy

* xfail ltc backend

* removed ComplexImag and ComplexReal from torchdynamo xfail set

* copied and pasted from main branch as there's no change to be made in this file

* refactored abstract_interp_lib_gen.py

* refactored: torchtypes.td, formatted, removed commented out code
2023-09-01 21:12:01 -07:00
Quinn Dawkins 1fc4314b62
Add folder for aten.broadcast_to on unchanged static shapes (#2421) 2023-09-01 14:50:34 -04:00
JianzheXiao 17d02811d5
[Torch Dialect] add folder for aten.any.bool (#2388)
* update

* update

* update

* update

* update

* update

* update
2023-08-30 17:29:03 +08:00
jinchen62 1682b540bf
Prototype passes for lowering quantized group matmul (#2402)
* Support brevitas custom op (#2320)

* f16 change for brevitas

* Adapt the change of brevitas quant custom op name

* Add unit tests

* Make brevitas conversions isolated

* Address the comments

---------

Co-authored-by: dan <danimal197@gmail.com>
2023-08-29 21:25:45 -07:00
Jiawei Wu 4c9d234b01
revert canonicalizer for PrimListConstructOp (#2408) 2023-08-22 09:18:39 +08:00
Jiawei Wu 60bad54f27
[Torch Dialect] replace none-index in aten.Index.Tensor's param by manually generating it (#2344)
* [Torch Dialect] replace none-index in aten.Index.Tensor's  param by manually generating it
Co-authored-by: Jiawei Wu <wujiawei.aml@bytedance.com>
Co-authored-by: Jianzhe Xiao <jianzhe.xiao@bytedance.com>

* minor typo fix

* add new failed e2e tests for ltc

* fix typo

* Address comments

* Add more e2e tests

* add failed e2e tests for LTC

* address comments

* remove decomposition for AtenIndexTensorHackedTwinOp
2023-08-15 19:36:08 +08:00
Ramiro Leal-Cavazos ff762100b8
Add handling of namespaces to library generator (#2391)
When using custom ops, sometimes PyTorch will insert namespaces to the
abstract interpretation function name in the format:
`__torch__.{namespace_1}.{namespace_2}...{op_name}`.  The extra
namespaces are not part of the abstract interpretation function name,
so it needs to be removed before generating the library of MLIR
snippets of abstract interpretation functions. This commit adds
support for removing the namespace information.
2023-08-11 09:56:19 -07:00
Jiawei Wu 4c12aceb81
[Torch-Dialect] add canonicalizer for prim::ListConstruct op (#2306)
[Torch-Dialect] add canonicalizer for prim::ListConstruct op
2023-08-08 10:28:11 +08:00
Gleb Kazantaev fb52a73cbe
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
2023-08-02 10:29:11 -04:00
Matthias Gehre 0a67411719
test/CAPI/CMakeLists.txt: Depend on FileCheck (#2329)
I saw test failing when FileCheck wasn't already build
2023-07-25 10:11:55 +02:00
Matthias Gehre c56cb531d5
Ignore constants in the legality error (#2328) 2023-07-25 10:11:40 +02:00
Jiawei Wu 026e8db2e4
[Stablehlo] add converter for aten.scatter.src op (#2295) 2023-07-24 10:14:45 +08:00
Alexandre Rames 1e468e8294 Fix canonicalization of `torch.prim.TupleUnpack`. 2023-07-20 20:08:46 +02:00
Alexandre Rames a20422ce65 Support `DerefineOp` in `RefinePublicReturn`. 2023-07-20 20:08:46 +02:00
Alexandre Rames 4847563bed Clean up verification of calling conventions.
The implementation at this place was a remnent of the times the pipeline was
run only once.
Rely instead on the backend verification, after optimizations have had an
opportunity to resolve some uncertainties. (e.g. `!torch.optional`).
2023-07-20 20:08:46 +02:00
Matthias Gehre 64d7626a52
Fixes for split tensor and slice (#2314)
* RecomposeComplexOps: Remove dead slice op

* lib/Dialect/Torch/IR/TorchOps.cpp: Fold slice ops even when they are on non-value tensors

* lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix slice start/end out of range/none

* lib/Dialect/Torch/IR/TorchOps.cpp: AtenSliceTensorOp::fold: Fold slices that go from 0:int_max

* More tests for aten.split.Tensor
2023-07-20 09:53:54 +02:00
Jiawei Wu 3f843c8fd9
[torch-dialect] fix aten.type_as op's folder (#2283)
[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtype
2023-07-20 09:51:58 +08:00
Ramiro Leal-Cavazos 718f53ff8a
Fix handling of `!torch.number` in abstract interpretation library (#2309)
In PyTorch, the `NumberType` is equal to `Union[int, float,
complex]`. However, the abstract interpretation library was treating
the `NumberType` as `Union[int, float]`, resulting in type mismatches
when reifying certain dtype functions. This commit fixes the type
inconsistency by having the abstract interpretation functions take as
an input a `Union[int, float, complex]` for the ops that take
`!torch.number` inputs.
2023-07-17 09:52:04 -07:00
Jiawei Wu c7fa42b7d3
[Torch Dialect] Add canonicalizer for aten.to.other op (#2273)
Canonicalize aten.to.other to prim.device + prim.dtype + aten.to.device
Co-authored-by: wujiawei.aml <wujiawei.aml@bytedance.com>
2023-06-30 09:43:08 +08:00
Yuanqiang Liu 449cfb8375
[Torch Dialect] add more scalar op folders (#2265) 2023-06-29 10:37:13 +08:00
Yuanqiang Liu 1ea2b57ab7
[Torch Dialect] add folder for aten.add (#2264)
* [Torch Dialect] add folder for aten.add

* update

* update

* update
2023-06-27 10:55:28 +08:00
Yuanqiang Liu 96b14e952e
[Torch Dialect] Support aten.device.with_index (#2254) 2023-06-23 01:07:14 +08:00
Vivek Khandelwal f6a6cfea4e
[MLIR][TORCH] Add support for negative index values for index.Tensor op (#2233)
This commit adds the support for index.Tensor op when the index values
are negative. This commit wraps around the index values by checking
their values at run time.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-06-16 14:21:04 -05:00
Yuanqiang Liu 7c6961bcbf
[Torch Dialect] Support aten.cuda and add canonicalizer for aten.cuda (#2231) 2023-06-14 09:56:39 +08:00
Maksim Levental 0caaf8d32a
Bump LLVM (#2176)
* Bump LLVM

---------

Co-authored-by: Matthias Gehre <matthias.gehre@xilinx.com>
2023-06-13 16:17:23 +02:00
Yuanqiang Liu ddea56a832
[Torch Dialect] fix torch.uint8's dtype infer (#2227) 2023-06-13 10:38:20 +08:00
Christopher McGirr b461daa06e
fix(TorchToTosa.cpp): adjust torch->tosa div conversion (#2200)
check the return type of the division to figure out whether to use
the floating point implementation of a division or to use the integer.

the issue rose from the fact that the inputs are all integer but the
result was casted to floating point. The conversion then chose to
use the integer implementation of division which is not legal in tosa
when all the inputs get casted to floating point.

fix(TorchToLinalg): AtenDivScalarOp

upcast self operand as well if applicable, the self operand must also
be casted to float as it can be an integer.
2023-06-12 11:18:38 +02:00
Matthias Gehre 27a3d09917
Torch: Fold RuntimeAssertOp when condition is true (#2198) 2023-06-09 19:06:25 +08:00
Yuanqiang Liu 5a7bf4e4cb
[Torch Dialect] Add canonicalize pattern for aten.is_floating_point (#2194)
* [Torch Dialect] Add canonicalize pattern for aten.is_floating_point

* implement as fold

* add lit test
2023-06-07 17:05:31 +08:00
Vivek Khandelwal da886280fe
[MLIR][TORCH] Add E2E support for aten.tril op (#2202)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-06-05 16:17:01 -07:00
Ramiro Leal-Cavazos dff3405d5a
Add alias analysis for cast-like ops to maximize-value-semantics (#2160)
When `use_tracing=True` is used to import a model into Torch-MLIR,
several casts get inserted in the IR to bridge the untyped inputs and
outputs with the typed body of the computation. These casts create
extra aliases of tensors that cause the current analysis in
`maximize-value-semantics` to fail.

In particular, the `maximize-value-semantics` analysis assumes that the
only valid alias right after an overwrite is the overwritten
alias. So, if there is a use of a casted version of the overwritten
alias after the overwrite, the analysis fails.

This commit improves the analysis by identifying all cast-like aliases
of the overwritten alias and allowing such aliases to be used after an
overwrite.

Because this issue only arises when using tracing, it cannot be
currently tested e2e, so only lit test is added.
2023-05-25 17:05:41 +00:00
Zhekun Zhang f0b7b63be0
[Stablehlo] Add aten.uniform lowering (#2101)
* add uniform stablehlo lowering

* add unit test

* new line

* rm redundant file

* Empty commit, trigger test

* fix include

* address comments

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
2023-05-25 10:32:55 +08:00
TatWai Chong ed4ecb072f
[tosa] support lowering basic torch binary ops with mixed dtypes (#2122)
Lowering torch operations that allow different compatible data types
in its operands to tosa end up generating invalid tosa IR with mixed
data types. In tosa spec, certain operations (generally element-wise
operations) require all operands to have the same data type.

Add wrapper functions for those element-wise tosa ops to perform op
creation with type conversion if necessary.
2023-05-18 17:12:18 -07:00
Ramiro Leal-Cavazos de02b56e17
Replace RefineTypes with dtype functions (#2105)
This commit adds dtype functions for all the torch ops that did not
previously have one and removes the pass `RefineTypes`, since the
abstract interpretation library now takes care of all the dtype
propagation.

All dtype functions added are tested except for
- `aten.embedding`
- `aten._embedding_bag`
- `aten.embedding_bag`

These functions need a change to the testing framework to allow
specifying the actual data inside the tensor used for testing. I will
fix this in a follow up patch.

Co-authored-by: Jiahao Li <liplus17@163.com>
2023-05-12 13:40:45 -07:00
Sean Silva d7614c261d Integrate LLVM
LLVM: 26ee8947702d79ce2cab8e577f713685a5ca4a55
MHLO: 4805d8498dfb81566076f56f52273b426c1cc5bf

Per: https://github.com/llvm/torch-mlir/issues/1178#issuecomment-1538492185
2023-05-09 10:14:27 -07:00
Chi_Liu 51e0a2c933
[Stablehlo] Add stablehlo support for aten.abs (#2068)
Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
2023-05-08 22:13:00 -07:00
Yuanqiang Liu ef6dae6ae2
[Linalg] fix lowering reduce max with -inf (#2097) 2023-05-08 09:17:49 -07:00