Commit Graph

2498 Commits (962d5143085b2bea7db0c0e9bdc26bf5ea8db2b5)
 

Author SHA1 Message Date
Xida Ren (Cedar) ccaac85788
implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757)
convolution with [time,batch,channel] ordering, as opposed to the
default [batch, channel, time]. Currently implementing by transposing
the input and output, but may need to get its own implementation in the
future because this is supposed to be an op that gives a speedup. This
is used by fairseq
(https://github.com/facebookresearch/fairseq/issues/172).

(in case you were wondering like me, this is different from transposed
convolution. Transposed convolution has fractional strides).

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Frederik Harwath <frederik.harwath@amd.com>
2024-01-23 21:30:03 -08:00
Chi_Liu 77ae56337d
[ONNX][MLIR] Add support for onnx.Exp op (#2792)
https://github.com/nod-ai/SHARK-Turbine/issues/312
2024-01-23 13:45:00 -08:00
James Newling dc056e58e6
[MLIR][TORCH] Add onnx.cast cases used by OPT-1.25M (#2787) 2024-01-23 21:06:25 +05:30
Vivek Khandelwal c9d8ffb414
build: manually update PyTorch version (#2788)
Set PyTorch and TorchVision version to nightly release 2024-01-22.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-23 21:05:19 +05:30
Gaurav Shukla b7a0329676
[ONNX][MLIR] Fix padding size constraint for onnx.maxpool op (#2782)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2024-01-23 19:23:01 +05:30
Dave Liddell d452c4f4c0
Fix onnx importer to treat Constant values as static (#2780)
Fixes  https://github.com/llvm/torch-mlir/issues/2764

In the case of OPT, there are ConstantOfShape ops whose input shape is
not static (that is, an initializer), but rather comes from a Constant
op. The importer can't handle such non-static input shapes.

The fix here is to create initializers for a subset of Constant ops
(ones with "value" attributes), so that their outputs can be used
statically. Additionally, there was no case for creating a splat of
int64, so I added that as well.

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-01-22 13:00:05 -08:00
Chi_Liu cad98e8113
[ONNX][TORCH-MLIR] Add TopK support (#2774)
https://github.com/nod-ai/SHARK-Turbine/issues/331
2024-01-22 12:56:39 -08:00
Ramiro Leal-Cavazos 5883ef0f21
Fix unused variable warnings (#2775) 2024-01-22 11:05:55 -08:00
Srinath Avadhanula 73b30604da
Do not try to legalize transposed convolution (#2721)
Currently transposed convolution is not handled correctly by
`TorchToTosa`. This PR allows transposed convolutions to pass through
the conversion so that they can be handled by other conversion passes
later in a pipeline.

An example input which produces a compilation error is:

```
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
  %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
  return %output : !torch.vtensor<[1,64,2,200],f32>
}
```

This MLIR produces an error about a cast operation with a size mismatch
when passed through `torch-to-tosa`:

```
 error: 'tensor.cast' op operand type 'tensor<1x64x1x50xf32>' and result type 'tensor<1x64x2x200xf32>' are cast incompatible
```

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-01-22 10:57:56 -08:00
Franz Haniel b9806cfa38
[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632) 2024-01-22 12:47:13 -05:00
James Newling 50ac3b1912
g++ build fix (#2778)
Introduced in 704cfdaf08 of @wu-s-john 

g++ compiler error: 

Pooling.cpp:177:13: error: explicit specialization in non-namespace
scope ‘class

Design looks good, g++ is just freaking out for no good reason.
Un-nesting the template classes fixes the error.

We don't have g++ CI. This hopefully happens infrequently enough that we
can just fix manually. My service to those folks who really like
building with g++... :)
2024-01-19 19:12:29 -08:00
Dave Liddell 2f4924015d
[onnx] Added flatten (#2760)
[https://github.com/nod-ai/SHARK-Turbine/issues/328](url)

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-01-19 16:18:16 -08:00
Scott Todd b3a3ad4e2a
Generalize install instructions to not exclude Windows. (#2771)
Overly specific docs can get stale easily. It looks like
https://llvm.github.io/torch-mlir/package-index/ has included Windows
packages since around https://github.com/llvm/torch-mlir/pull/1521.

Here's an example release:
https://github.com/llvm/torch-mlir/releases/tag/snapshot-20240118.1087

```
torch-2.3.0.dev20240109+cpu-cp311-cp311-linux_x86_64.whl
torch-2.3.0.dev20240109+cpu-cp311-cp311-win_amd64.whl
torch-2.3.0.dev20240109+cpu-cp38-cp38-linux_x86_64.whl
torch-2.3.0.dev20240109-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0.dev20240109-cp311-none-macosx_10_9_x86_64.whl
torch_mlir-20240118.1087-cp311-cp311-linux_aarch64.whl
torch_mlir-20240118.1087-cp311-cp311-linux_x86_64.whl
torch_mlir-20240118.1087-cp311-cp311-macosx_11_0_universal2.whl
torch_mlir-20240118.1087-cp311-cp311-win_amd64.whl
torch_mlir-20240118.1087-cp38-cp38-linux_x86_64.whl
```
2024-01-19 15:13:32 -08:00
Xida Ren (Cedar) 18669b38cb
Create add_ops.md (#2770) 2024-01-19 10:44:45 -08:00
Gaurav Shukla 3b85c70748
[ONNX][MLIR] Add support for onnx.gather op (#2726)
This commit adds support for gather op in the onnx pipeline.
https://github.com/nod-ai/SHARK-Turbine/issues/242

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
2024-01-19 21:58:29 +05:30
John Wu 704cfdaf08
Add aten.pool_max3d support to torch-to-linalg (#2735)
Added verification logic to the abstract_interpreter_lib_gen.py

Also made some unit tests

Initially, I thought we can use `linalg::pooling_ndhwc_max` to help
implement this problem. However, on a 5-dimensional matrix it does the
pooling on dimensions (2, 3, 4) which is not what we want. We want
pooling on dimensions (3, 4, 5).

To achieve this, we would need to lower our code using the `linalg`
dialect.


Turns out the pooling code in `linalg` looks like this.

```
func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                        %strides: memref<3xindex>, %dilations: memref<3xindex>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
    %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
    %D = memref.dim %I, 2 : memref<?x?x?x?x?xf32>
    %H = memref.dim %I, 3 : memref<?x?x?x?x?xf32>
    %W = memref.dim %I, 4 : memref<?x?x?x?x?xf32>

    %kernel_d = memref.load %K[%c0] : memref<3xindex>
    %kernel_h = memref.load %K[%c1] : memref<3xindex>
    %kernel_w = memref.load %K[2] : memref<3xindex>
    %stride_d = memref.load %strides[%c0] : memref<3xindex>
    %stride_h = memref.load %strides[%c1] : memref<3xindex>
    %stride_w = memref.load %strides[2] : memref<3xindex>
    %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
    %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
    %dilation_w = memref.load %dilations[2] : memref<3xindex>

    linalg.generic {
        indexing_maps = [
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>,  // Map for input tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>,                                              // Map for kernel tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)>                                            // Map for output tensor
        ],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
        doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size"
    } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
        ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
            %max_val = arith.maxf %input_elem, %output_elem : f32
            linalg.yield %max_val : f32
    }
    return
}

```

This was implemented based on it's source code with the adjustments
mentioned above:

4ca1b5e094/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (L5647)

Issues related to this can be found here

https://github.com/nod-ai/SHARK-Turbine/issues/324
2024-01-19 21:09:46 +05:30
Ilija Kalinić faa4517e83
Implement lowering of torch.aten.remainder.Tensor (#2763)
Closes nod-ai/SHARK-Turbine#349
2024-01-19 18:09:08 +05:30
Andreas Falkenberg 4de4d38b87
Initial commit of NonZero op (#2766) 2024-01-18 15:23:13 -10:00
Rob Suderman b5387c0f29
[onnx] Lowering `onnx.dequantize_linear` to `torch` (#2759)
We can make the per-tensor version of the operation to the dequantize
operation via marking with the make quantized tensor component. This
introductions the `qint*` and `quint*` tensor type that can be lowered
to teh appropriate dequantization behavior during the torch-to-linalg
conversion.
2024-01-18 16:47:21 -08:00
Rob Suderman bd11877f6f
[onnx] Support lowering quantize linear to `torch` (#2751)
We can map the per_tensor case to the `torch.aten.quantize_per_linear`
operation. In this case we extract the `scale` and `zeropoint` values
and directly invoke the quantization, then return the integer
representation value.
2024-01-18 16:33:10 -08:00
Ze Zhang 77a03f2069
torch-to-tosa lowering support for AtenLinalgVectorNormOp (#2734)
This PR add torch-to-tosa lowering support for AtenLinalgVectorNormOp

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-01-18 12:32:23 -08:00
Phaneesh Barwaria eed144bfbc
[ONNX][MLIR] add Identity op support (#2754) 2024-01-16 19:06:54 +05:30
Sungsoon Cho a8538e1e3f
Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737) 2024-01-15 22:49:29 -08:00
lonely eagle f85e5c932b
[Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num (#2743) 2024-01-16 14:29:34 +08:00
James Newling f78ec78ac8
Adjust bound check to be the same as PyTorch native (i.e. stricter) (#2755)
prims.expand expects the start and end dimensions to be strictly less
than the rank of the tensor.
2024-01-15 11:44:45 -08:00
kumardeepakamd 87389f0762
[ONNXToTorch] Add conversion for Onnx range (#2752)
Implemented ONNX.Range. The spec says the data type for start, limit,
delta are 0-D can be double, float, int16, int32, int64, All int types
mapped to !torch.int and all float types mapped to !torch.float

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
2024-01-15 14:26:46 -05:00
lisaliu1 09421b1cf3
[TorchToLinalg] Add lowering for aten.replication_pad2d (#2715)
Co-authored-by: Lisa Liu <lingl@xilinx.com>
2024-01-15 14:02:27 -05:00
Rob Suderman 197b3b475c
[onnx] Convert `onnx.constant` to `torch` literal tensor (#2748)
Handles the multiple cases of `onnx` constant values and converts them
to `torch` literal tensors. This can include splats with a single
integer or floating point value, a set of explicit integer values, or
an elements array attr of values.
2024-01-15 09:31:22 -08:00
Han-Chung Wang 10acea71be
Bump LLVM to llvm/llvm-project@0cb024b (#2753)
- Add fixes for
af78e5daf0
- Add fixes for
bb6d5c2200
2024-01-15 07:12:12 -08:00
Rob Suderman dc37616d67
[torch][quant] Support quantize and dequantize for torch (#2731)
Handle both `torch.dequantize` and `torch.quantize_per_tensor` including
the op based quantization parameter tracking. This includes adding
`qint32` to torch types as it was missing during the initial type
inclusion.

For testing we only have `torch.int8` and `torch.float` types on
function boundaries as the `qint8` types require passing the scale
and zero point quantization information which is not supported yet.
2024-01-12 19:11:14 -08:00
Chi_Liu c7452af4fa
[MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695)
Add Maxpool ONNX op support.
Add Utils.h/cpp files to create a constant int list for ONNX.
2024-01-12 14:54:38 -08:00
Ze Zhang 670a99ae19
Handle torch.none type in tosa.clamp op (#2739)
This PR updates the torch-to-tosa conversion with following changes:

- Support torch.none as min/max input argument for tosa.clamp op
- Support negative value as start index for tosa.slice op
- Add tosa.logical_or lowering support

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-01-11 10:36:48 -08:00
James Newling 47ffc90db4
signed/unsigned c++ compiler warning fixes (#2742) 2024-01-11 09:46:46 -08:00
Ilija Kalinić e1a86e480a
Implement lowering of torch.aten.logit (#2697)
Closes nod-ai/SHARK-Turbine#290
2024-01-11 20:25:42 +05:30
Andreas Falkenberg 5862854bc8
[ONNX][TORCH-MLIR] LayerNorm (#2716)
Layer Normalization using the torch.aten.native_layer_norm 

https://github.com/nod-ai/SHARK-Turbine/issues/325
2024-01-11 14:27:04 +05:30
Frederik Harwath 0860c41ee2 Implement aten.reflection_pad2d lowering to linalg 2024-01-10 21:32:22 -10:00
Xida Ren (Cedar) aee1fca251
Minor typo fix: in not implemented message for the exclusive and reverse attributes for cumsum (#2740) 2024-01-10 14:24:37 -08:00
kumardeepakamd 29569713f3
support for onnx.expand operator (#2729)
maps onnx.expand to torch aten broadcast_to, three tests added

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
2024-01-10 13:05:37 -08:00
Vivek Khandelwal 469c055190 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2024-01-09.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 17:51:00 +05:30
Vivek Khandelwal 208ae35583 [MLIR][ONNX] Add TorchToOnnx Support for DepthToSpace op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 17:50:47 +05:30
Vivek Khandelwal 4707d3bdc6 [MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 16:24:06 +05:30
Vivek Khandelwal 35e8f86792 [MLIR][ONNX] Add OnnxToTorch support for Dropout and Elu op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-10 16:23:55 +05:30
zjgarvey 07d0645f64
[RFC] general support for Adaptive Pooling Ops (#2661)
Adaptive pooling ops can only be decomposed into their non-adaptive
counterparts in trivial cases.

For example, the current decomposition for AtenAdaptiveAvgPool1dOp in
DecomposeComplexOps.cpp supports outSize = inSize (i.e., do literally
nothing), and outSize = 1 (i.e., do a batched average).

The reason adaptive pooling ops are difficult to lower to linalg is that
they are not constantly strided. They are computed by taking an input
tensor of shape (N, C, Hin), and an output size Hout, and computing the
output tensor at position (n,c, h) in the following way:

1. compute st(h) = (h*Hin)//Hout
2. compute en(h) = 1 + ((h+1)*Hin -1)//Hout
3. apply a computation (max or avg) to the slice: INPUT[n, c,
st(h):en(h)]

The provided sample implementation (for ConvertAtenAdaptiveAvgPool1dOp)
uses tensor.extract to access the input tensor inside the payload of a
linalg generic op. This is likely an unattractive use of linalg generic
ops, which is why I am asking for some more targeted feedback on the
validity of this approach before attempting to support the many other
adaptive pooling ops.

Specifically:

- Is the performance of this implementation bad enough to warrant
targeting different dialects entirely? e.g. TMtensor/linalg ext/ etc.
- If the provided implementation is of acceptable performance to the
community, then is it permissable to remove the Adaptive pooling
decompositions from DecomposeComplexOps.cpp? Based on the current
structure of the -torch-decompose-complex-ops pass, it does not seem
possible to only decompose the adaptive ops in special cases (it seems
to get stuck in an infinite loop on a match failure). I would be happy
to instead incorporate the case logic into the conversion directly, and
remove the decompositions once they are rendered completely obsolete.

As long as this approach is acceptable, I can clean up the
implementation with some helper functions, and quickly add support for
each of the remaining Adaptive pooling ops.
2024-01-09 11:14:10 -08:00
Ben Vanik 4dd17f0b71
Fixing implicit double->float truncation warnings. (#2733)
Floating-point literals should use the correct type specifier.
2024-01-08 17:26:38 -05:00
Rob Suderman 985e7796a4
[linalg] Added `aten.clamp` support with integers to `torch-to-linalg` (#2718)
The lowering for `aten.clamp` did not support integer types. Added
support for integer types including a signed integer test.
2024-01-05 15:16:49 -08:00
Han-Chung Wang 6096fcb347
[OnnxToTorch] Delete unused variables. (#2728) 2024-01-04 17:30:05 -08:00
Kunwar Grover fb1dfa3126
Bump llvm-project to 6b65d79fbb4682468333cea42b62f15c2dffd8f3 (#2723)
Co-authored-by: hanhanW <hanhan0912@gmail.com>
2024-01-04 14:33:41 -08:00
Aart Bik aa7e95f7c8
[torch-mlir] remove trailing whitespace from e2e test files (#2727) 2024-01-04 14:09:12 -08:00
John Wu 4e5e34d215
[MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696) 2024-01-03 19:41:10 -08:00
Aart Bik 3e9bacdb51
[torch-mlir] update e2e test class documentation (#2722)
The doc seems copy-and-paste from the linalg-on-tensors class
2024-01-03 16:10:50 -08:00