Commit Graph

721 Commits (1d6aca3823e33a026320dadfd4d680452758edfa)

Author SHA1 Message Date
Quinn Dawkins 494089d53d
Clang format refresh (#2812)
After noticing a number of commits with unrelated formatting changes,
I think something was changed with clang-format at one point and we're
seeing a number of unrelated changes. Doing a refresh can help avoid
this.

The changes made here came from
```
find lib -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find include -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find projects -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
```
2024-01-29 12:59:33 -05:00
Rob Suderman 2ef228328f
[torch] `torch.dequantize` for per channel tensors to` linalg` (#2769)
Support a lowering for dequantization for per channel tensors from
`torch` dialect to a linalg decomposition. Tested via a numerical
`torch` test.
2024-01-25 16:40:21 -08:00
Aart Bik e824fbc65c
[torch-mlir][torch] add encoding field to torch type (#2799)
This adds an encoding field to the torch type, using the interfaces for
printing, parsing, and verification. Note that although this change
prepares adding sparsity to the torch type (as illustrated by the round
trip and invalid tests), nothing in this change depends on the actual
contents of the encoding field!
2024-01-25 10:04:04 -08:00
Rob Suderman f6f890520b
[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)
This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As a validation we compare
to the dequant+mm torch implementation.
2024-01-24 14:02:50 -08:00
zjgarvey c531f5495b
AtenAdaptiveMaxPool2d Conversion to Linalg (#2779)
The logic here is very similar to the conversion for AdaptiveAvgPool1d
#2661 with a few modifications:

1. buffVal = -inf instead of 0
2. the main linalg generic op accumulates a max, instead of a sum, to
the first output tensor
3. avg pooling requires dividing the sum pool by the kernel width, which
we stored as an auxilliary tensor (kSizeTensor). Here, the auxiliary
tensor will be recording the indices. Strangely enough, the only
signature available for this function is to return indices, and it
appears that they must be computed whether the user desires them or not.
See
[pytorch/torch/nn/functional.py](https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L1174).

Before writing other adaptive pooling conversions, the logic of this
decomposition should be rolled into a helper function that will work for
both max and avg pooling ops. Even the auxiliary tensor should likely be
automated. This code was written in a slightly more tedious way than
strictly necessary (often using loops to fill SmallVectors up to rank-2,
which is only two in this case), in order to more easily facilitate the
transition to a helper function.
2024-01-24 09:09:56 -08:00
Xida Ren (Cedar) ccaac85788
implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757)
convolution with [time,batch,channel] ordering, as opposed to the
default [batch, channel, time]. Currently implementing by transposing
the input and output, but may need to get its own implementation in the
future because this is supposed to be an op that gives a speedup. This
is used by fairseq
(https://github.com/facebookresearch/fairseq/issues/172).

(in case you were wondering like me, this is different from transposed
convolution. Transposed convolution has fractional strides).

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Frederik Harwath <frederik.harwath@amd.com>
2024-01-23 21:30:03 -08:00
Franz Haniel b9806cfa38
[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632) 2024-01-22 12:47:13 -05:00
Ze Zhang 77a03f2069
torch-to-tosa lowering support for AtenLinalgVectorNormOp (#2734)
This PR add torch-to-tosa lowering support for AtenLinalgVectorNormOp

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-01-18 12:32:23 -08:00
lonely eagle f85e5c932b
[Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num (#2743) 2024-01-16 14:29:34 +08:00
lisaliu1 09421b1cf3
[TorchToLinalg] Add lowering for aten.replication_pad2d (#2715)
Co-authored-by: Lisa Liu <lingl@xilinx.com>
2024-01-15 14:02:27 -05:00
Rob Suderman 197b3b475c
[onnx] Convert `onnx.constant` to `torch` literal tensor (#2748)
Handles the multiple cases of `onnx` constant values and converts them
to `torch` literal tensors. This can include splats with a single
integer or floating point value, a set of explicit integer values, or
an elements array attr of values.
2024-01-15 09:31:22 -08:00
Rob Suderman dc37616d67
[torch][quant] Support quantize and dequantize for torch (#2731)
Handle both `torch.dequantize` and `torch.quantize_per_tensor` including
the op based quantization parameter tracking. This includes adding
`qint32` to torch types as it was missing during the initial type
inclusion.

For testing we only have `torch.int8` and `torch.float` types on
function boundaries as the `qint8` types require passing the scale
and zero point quantization information which is not supported yet.
2024-01-12 19:11:14 -08:00
Chi_Liu c7452af4fa
[MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695)
Add Maxpool ONNX op support.
Add Utils.h/cpp files to create a constant int list for ONNX.
2024-01-12 14:54:38 -08:00
James Newling 47ffc90db4
signed/unsigned c++ compiler warning fixes (#2742) 2024-01-11 09:46:46 -08:00
Ilija Kalinić e1a86e480a
Implement lowering of torch.aten.logit (#2697)
Closes nod-ai/SHARK-Turbine#290
2024-01-11 20:25:42 +05:30
Frederik Harwath 0860c41ee2 Implement aten.reflection_pad2d lowering to linalg 2024-01-10 21:32:22 -10:00
Vivek Khandelwal 208ae35583 [MLIR][ONNX] Add TorchToOnnx Support for DepthToSpace op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 17:50:47 +05:30
John Wu 4e5e34d215
[MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696) 2024-01-03 19:41:10 -08:00
Andreas Falkenberg 80bd093d56 Added tensorResultTypeAtIndex to Patterns.h
Need this for LayerNorm
2024-01-03 11:08:36 +05:30
kumardeepakamd 9adad9bc40
Add support for reflection_pad1d (#2706)
Adds a lowering to Linalg for reflection_pad1d. Based on ideas/code from draft PR
https://github.com/llvm/torch-mlir/pull/2693.

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
2024-01-02 14:05:11 -05:00
Sungsoon Cho 8e389ff2ff
Implement lowering of torch.aten.exponential (#2680)
https://github.com/llvm/torch-mlir/issues/2646

Decompose aten.exponential() into: -exp(1-x)/lambda
2023-12-27 20:33:18 -08:00
aldesilv 2d796b7502
lower onnx max op to torch aten maximum op (#2618)
lower onnx min op to torch aten minimum op
2023-12-27 11:07:35 -08:00
John Wu 46f2cb50dc
[onnx] Lower onnx.HardSigmoid to torch (#2682)
The expression for HardSigmoid in Onnx
(https://onnx.ai/onnx/operators/onnx__HardSigmoid.html): max(0, min(1,
alpha * x + beta))

is inherently different from HardSigmoid in Torch
(https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html)
which is: if x < -3 -> 0
elif x > 3 -> 1
else x/6 + 1/2

That being said, it was just better to compute out the entire expression
when translating the Onnx expression to Torch mlir, which is done in
this PR. Some of the logic is shared from the files in
`DecomposeComplexOps`. Therefore, refactored some shared logic between
`DecomposeComplexOps` and `DefaultDomainGToP` and put it in a `Utils`
file.
2023-12-21 07:29:22 -08:00
Rob Suderman 11cc92d4ab
[onnx] Lowerings from `onnx.tan` (#2642)
Started work on the `tan` lowerings for ONNX to Torch. Uses `sin` and
`cos` to represent a `tan`.
2023-12-20 10:09:39 -08:00
Rob Suderman 61888690bb
[onnx] Add support for `onnx.sinh` (#2643)
Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding
the `aten.sinh` operator.
2023-12-15 21:23:51 -08:00
Rob Suderman 705ea958ae
[onnx] Lowerings from `onnx.transpose` (#2641)
Lowerings for `transpose` from ONNX to `aten`. Implementation depends on
making multiple `aten.transpose` operations swapping pairs of dimensions.
As `onnx.transpose` can swap around any dimensions it may require
constructing multiple `aten.transpose`.
2023-12-15 15:30:05 -08:00
Quinn Dawkins 030b0140d4
[TorchToLinalg] Lower aten.cat to tensor.concat (#2650)
This replaces the lowering of aten.cat with tensor.concat, allowing more
efficient handling of concatenations in downstream flows. The refbackend
populates concat decomposition patterns that can be used to recover the
previous lowering.
2023-12-15 15:45:32 -05:00
Rob Suderman 061af696ce
[onnx] Lowering for `onnx.shape` to `torch` and `tensor` (#2648)
Includes the lowering from the `aten` equivalent to `tensor` operations.
2023-12-15 11:37:49 -08:00
Sungsoon Cho 55e9401c5c
Implement lowering of aten.cosh op. (#2635) 2023-12-15 11:19:26 -08:00
Gaurav Shukla eb9249e601
[ONNX][MLIR] Add support for LeakyRelu and GatherElements op (#2655)
This commit adds support for `LeakyRelu and GatherElements` op in the
onnx pipeline.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-12-15 11:18:28 -08:00
saienduri f59c01fd2f
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-15 09:36:18 -08:00
Rob Suderman 4857606ffe
[onnx] Lowerings from `onnx.selu` (#2634)
Lowerings for `selu` lowerings for ONNX to the corresponding torch
implementations. Torch's `selu` implementation has fewer features so
we use the a generalized `elu` with the input scale set to `1.0`.
2023-12-14 08:53:47 -08:00
JianzheXiao 6ddeb1a6ef
[torch] Add support for aten.selu (#2640)
Add `aten.selu` operation to `torch` dialect.
2023-12-13 20:28:08 -08:00
JianzheXiao 7cf52ae73f
[Torch Dialect]Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591)
Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
2023-12-13 11:05:12 +08:00
Stella Laurenzo 74f7a0c9d6
Upstream the ONNX importer. (#2636)
This is part 1 of 2, which will also include upstreaming the FX
importer. I started with ONNX because it forces some project layout
updates and is more self contained/easier as a first step.

Deviating somewhat from the RFCs on project layout, I made the following
decisions:

* Locating the `onnx_importer.py` into `torch_mlir.extras` as Maks
already has opened up that namespace and it seemed to fit. Better to
have fewer things at that level.
* Setup the build so that the root project only contains MLIR Python and
pure Python deps (like the importers), but this can be augmented with
the `projects/` adding more depending on which features are enabled.
* The default build continues to build everything whereas in
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode, it builds a
`torch-mlir-core` wheel with the pure contents only.

`onnx_importer.py` and `importer_smoke_test.py` are almost verbatim
copies from SHARK-Turbine. I made some minor local alterations to adapt
to paths and generalize the way they interact with the outer project. I
expect I can copy these back to Turbine verbatim from here. I also
updated the license boilerplate (they have the same license but slightly
different project norms for the headers) but retained the correct
copyright.

Other updates:

* Added the ONNX importer unit test (which also can generate test data)
in lit, conditioned on the availability of the Python `onnx` package. In
a followup once I know everything is stable, I'll add another env var
that the CI can set to always enable this so we know conclusively if
tests pass.
* Moved the ONNX conversion readme to `docs/`.
* Renamed CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` ->
`TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted the sense. Made the
JitIR importer and LTC options `cmake_dependent_options` for robustness.
2023-12-12 19:02:51 -08:00
Frederik Harwath b656c674ee Implement e2e support for aten.acos op
This depends on a change in the LLVM core repository which adds acos
support to the MLIR Math dialect.
2023-12-12 10:52:02 +01:00
Sambhav Jain 7acabafd84
Remove folder from `AtenStackOp` for single element list inputs (#2626)
`AtenStackOp` defines this folder for list operand containing single
element:
```
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
  auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
  if (!list || !list->hasOneUse() || list.getElements().size() != 1)
    return nullptr;
  return list.getElements()[0];
}
```
However, unlike `AtenCatOp`, `AtenStackOp` cannot be folded away for
single element list operand because the result from a stack operation
contains an additional dimension (of size 1, like expand_shape).

This PR removes the `AtenStackOp::fold` method, and adds an e2e test for
single element list input case, which fails on current `main` as
follows:
```
Unexpected outcome summary: (linalg)                                                                                                                                                                   
                                                                                                                                                                                                       
****** Failed tests - 1 tests                                                                                                                                                                          
    FAIL - "TensorsStackSingleElementListModule_basic"                                                                                                                                                 
        @ trace item #0 - call to "forward"                                                                                                                                                            
        @ output of call to "forward"                                                                                                                                                                  
        ERROR: shape (torch.Size([10, 32])) is not equal to golden shape (torch.Size([10, 1, 32]))     
```
Thanks Chris Lalau Keraly for the bug report.
2023-12-11 10:52:50 -08:00
Vivek Khandelwal 0b4422a253 [MLIR][ONNX] Add OnnxToTorch support for bitwise and math ops
This commit adds the OnnxToTorch support for BitwiseXor, BitwiseOr, Div, Equal, Cast,
Ceil, Floor, Cos, and Clip op.
This commit also adds the TorchToLinalg support for aten.clamp.Tensor and aten.clamp_min.Tensor op.

Signed-Off By: vivekkhandelwal1424@gmail.com
2023-12-11 19:36:01 +05:30
JianzheXiao 96fcde4d77
[Torch Dialect] Support Einsum Op (#2230)
As title, support torch.aten.einsum op

Right now only support Static Shape, because of the known issue, the
fixed solution is here: https://github.com/llvm/torch-mlir/pull/2154

Co-authored-by: Jiawei Wu
[wujiawei.aml@bytedance.com](mailto:wujiawei.aml@bytedance.com)
2023-12-10 12:30:37 +08:00
Stella Laurenzo 8252656b6d
Advance llvm-project and stablehlo. (#2619)
llvm-project: bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df
stablehlo: ab709fe48de88c67717abfbd7ef17425eb95ddaf

These commits were chosen in order to account for an MLIR API break from
3dbac2c007
which required a patch to stablehlo. We integrate a bit beyond that
commit to deal with some revert/reapply cycles in the intervening range
which were discovered in another downstream.

Further, it requires adaptation to the stablehlo API breaks introduced
from https://github.com/openxla/stablehlo/pull/1872 which are along for
the ride.

Since some stablehlo builders were changed to directly take int64_t
array refs, also traced that up some call stacks to eliminate some
signed/unsigned mismatches that result.

Also adds a few TOSA tests to the passing set that seem to work now.
2023-12-07 23:13:42 -08:00
Frederik Harwath 6244f301fb
Regenerate GeneratedTorchOps.td after recent change to torch_ods_gen.py (#2612)
Try to fix the error reported by @qingyunqu in #2609.
2023-12-05 08:04:32 -08:00
Quinn Dawkins 400752ca8d
[TorchToLinalg] NFC: Move Utils.h to an externally accessible location (#2603) 2023-12-01 19:38:21 -05:00
Ramiro Leal-Cavazos e568f7e999
Move handling of integer signedness to the backend conversions (#2597)
The function `getTypeForScalarType` currently takes an argument to
specify the signedness of integer types. This is leakage of backend
specific requirements into the torch dialect world. Because
`getTypeForScalarType` is a utility function for the torch dialect, it
should only produce types that match the sign conventions used by
PyTorch (regular integers are signed and unsigned integers are
unsigned).

This commit removes the signedness argument from
`getTypeForScalarType`, and moves the backend specific handling of
integer types to the backend code.
2023-11-29 09:43:09 -08:00
Vivek Khandelwal dc9ea08db5 [MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops
This commit adds the OnnxToTorch support for Atan, Bitshift, BitwiseAnd,
and BitwiseNot op.
This commit also adds the TorchToLinalg support for AtenBitwiseLeftShiftTensorOp.

Signed-Off By: vivekkhandelwal@nod-labs.com
2023-11-28 17:19:07 +05:30
Stella Laurenzo e06efc5136
Initial TorchOnnxToTorch conversion pipeline. (#2585)
Adds a pipeline to convert custom ops and metadata represented as
`torch.operator` custom ops to corresponding `torch` ops where possible.

This is part of a multi-part approach for building ONNX import in as a
regular feature of torch-mlir. It is focused on the conversions vs the
infra. We will end up maintaining a [pure-python
importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
to go with this in torch-mlir, and we will also maintain test case
generation utilities derived from it.

I have left substantial documentation in the README of the conversion
directory, including the recommended approach that we will take to keep
building this out.

(note that this organizes the code to coincide with the refactoring in
#2442 versus the current flat arrangement)
2023-11-21 21:02:55 -08:00
James Newling 03e8f99730
Lowering to linalg of prims split_dim op (#2576)
Adds support for lowering to prims split_op. 

Similar design to collapse op lowering in 
https://github.com/llvm/torch-mlir/pull/2572, with some 
small differences, because the split_dim op (in pytorch) is
view-changing whereas the collapse is not. The difference 
means that 

1) it must be registered in the function Torch::isViewLikeOp
2) it must be be added to the "expected fail" set for the torch dynamo backend.
2023-11-21 07:56:09 -08:00
Zhekun(Josh) Zhang d67afa9e95
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543) 2023-11-21 13:26:17 +08:00
Stella Laurenzo 5eae0adff1
Breakup python pytorch deps (#2582)
This lifts the core of the jit_ir_importer and ltc out of the pt1
project, making them peers to it. As a side-effect of this layering, now
the "MLIR bits" (dialects, etc) are not commingled with the various
parts of the pt1 project, allowing pt1 and ltc to overlay cleanly onto a
more fundamental "just MLIR" Python core. Prior to this, the Python
namespace was polluted to the point that this could not happen.

That "just MLIR" Python core will be introduced in a followup, which
will create the space to upstream the FX and ONNX pure Python importers.

This primary non-NFC change to the API is:

* `torch_mlir.dialects.torch.importer.jit_ir` ->
`torch_mlir.jit_ir_importer`.

The rest is source code layering so that we can make the pt1 project
optional without losing the other features.

Progress on #2546.
2023-11-19 12:10:19 -08:00
James Newling dad1f012f6
Add verification for torch permute op (#2551)
- adds support for an optional verifier to the generated torch op
tablegen (GeneratedTorchOps.td)
- uses the above to add a verifier for the torch permute op. 

Motivation: I hit an unclear error from linalg while developing a
decomposition pass for pixel_shuffle. The error would have been clearer
if the problem had been detected earlier in the invalid aten.permute op.

Testing: new tests added. To run added tests, from the base directory
run

```
 ./build/bin/llvm-lit  test/Dialect/Torch/invalid.mlir
 ```
2023-11-15 11:47:54 -08:00
James Newling e81282ae8f
Support for prims collapse op (lowering to linalg) (#2572)
Steps taken:
1) add generator code to torch_ods_gen.py, run update_torch_ods.sh
2) add (custom) shape and type inference generator code to
abstract_interp_lib_gen.py, run update_abstract_interp_lib.sh
3) Implement lowering to tensor.collapse_dims. Requires the `start` and
`end` values to be constant, else lowering fails
4) Update xfail_sets.py (append to LTC_XFAIL_SET) after running
/tools/e2e_test.sh --filter Collapse --verbose -c XX for all support
backends (XX).

Motivation: 
- Supporting the collapse operation will be useful for lowering of
pixel_shuffle (see Issue #2559)
2023-11-15 08:34:38 -08:00