- Add Torch to TOSA legalization for the following reduction ops:
+ aten.min.dim
+ aten.min
+ aten.max
+ aten.prod
+ aten.prod.dim_int
+ aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
legalization
- Update end-to-end tests sets in xfail_sets.py
Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0
Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Follow up cleanup for [this
PR](https://github.com/llvm/torch-mlir/pull/3689), which introduced a
decomposition for `aten.fmod.Tensor`. This means that the lowering for
this operator in linalg is no longer needed.
Thanks to @vivekkhandelwal1 for pointing this out.
---------
Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.
The tests added highlight the new capabilities introduced in this PR,
including:
Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes
Made sure that one cannot input both mask and is_causal.
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.
This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`
---------
Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
The lowering pattern for `aten.T` uses transposition implemented via
`linalg.generic`. For downstream passes it is advantageous to use named
ops wherever possible, so this patch changes the lowering to use
`linalg.transpose` instead.
Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.
This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:
```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```
Will be replaced by
`torch.aten.mul.int %input, %const-30`
And
```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`
This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.
To test:
`cmake --build build --target check-torch-mlir-all`
Supports the result with dynamic shape and scalar indices like
```
func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
```
`Torch::AtenSqueezeOp` is referring to the result shape, so it will
failed on lowering if the result shape is dynamic.
The current implementation uses a `linalg.generic` to broadcast the bias
tensor for the lowering of convolutions. This is suboptimal for later
pattern matching. This patch changes it to use the respective named op,
`linalg.broadcast`, instead.
The `axis` attribute is optionally available. Added support by computing
the pad based on the axis values.
---------
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
- This PR adds new (and equivalent) more tensorized impl of
MelWeightMatrix which lowers all the way to linalg.
- [Ref Pytorch
Impl](https://gist.github.com/PhaneeshB/4e6dfcded3007b1b686fbe28f07a67cd)
- Thanks to @rsuderman for pointing out the difficulties [earlier
impl](#3503) posed during lowering to linalg and also for providing a
better numpy impl 🙏
This commit adds the shape info for the tensors created during the
decomposition of GroupNorm op.
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Set PyTorch and TorchVision version to nightly release 2024-08-18.
This commit also updates the `scaled_dot_product_attention` op.
A new attribute `enable_gqa` has been added. As of now, only the
default value for the same is supported.
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Discovered in https://github.com/llvm/torch-mlir/issues/3104
Most likely when building with stablehlo, while waiting for it missing
dependency was generated to location shared with another dependency.
This commit extends the OnnxToTorch lowering for BatchNormalization op
for supporting the case when training=True.
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
The `layout` attribute was not considered for the `onnx.RNN` operation.
Added support for the attribute to transpose the inputs / outputs of the
RNN when valid.
The einsum lowering was missing the behavior for duplicate indices in
the equation. This amounts to a diagonalization along duplicate pairs of
indices in the equation.
Closes#3575
The PyTorch remainder operator is meant to compute the Python modulus
operator entrywise:
https://pytorch.org/docs/stable/generated/torch.remainder.html#torch.remainder
In python the modulus operator is meant to always return a result with
the same sign as the divisor:
https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations
In other words, torch.aten.remainder should return a Python-style
modulus instead of a C-style modulus. However the remainder operator was
simply translated into arith.ModSI or arith.ModF, which both effectively
compute the C-style modulus. Now the lowering has been modified so that
the modulus operator works properly with negative numbers, both in the
dividend, and the divisor.
This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now but can be used in QDQ quantized form.
Using this patch, we're able to import and execute (on the linalg
backend) graphs with per-channel quantization applied using the "new"
PyTorch 2.0 Export Quantization.
The saga of aligning onnx and torch padding conventions continues.
```python
onnx_pads = [low_x, low_y, low_z, high_x, high_y, high_z]
torch_pads = [low_z, high_z, low_y, high_y, low_x, high_x]
```
So not only is the lexicographical ordering hierarchy swapped (low/high
x spatial-dim -> spatial-dim x low/high) but the ordering in the the
spatial-dim specification is also reversed.
This patch properly reverses the pad ordering (and actually uses the
`shuffledPadding` to pad).
`onnx.Shape` can select only a subset of indices using attributes. Add
support for these attributes.
---------
Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Following up from the discussion in
<https://github.com/llvm/torch-mlir/pull/3550>, I've edited the lowering
to prevent OOB extracts in a more direct fashion (i.e., just clamping
directly).
I don't think this affects the lit tests at all, but I've tested the
changes in our external test suite at
<https://github.com/nod-ai/SHARK-TestSuite/tree/main/>. I found the
issue when I was unexpectedly getting `nan`'s along the output image
border for a resize test there.
Change linalg.matmul_unsigned to linalg.matmul with unsigned type_fn
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
There were two issues related to `ignore_index` being set
(1) the onnx-to-linalg pass as not reading the value correctly (2) the
mean pass was not considering the `ignore_index` value
For (2) when taking the mean we need to know how many of the values were
considered in the sum and therefore we cannot divide by the total number
of elements. Adding a summation across the total number should correct
this issue.
The static uneven divisible AdaptiveAvgPool2d means that although the
input size is not an integer multiple of ouput size, but the kernel and
stride size can also be fixed (not dynamic). The derivation logic of
kernel and stride size is consistent with
torch/_decomp/decomposations.py:adaptive_avg_pool2d as described in the
following:
1. Stride Size
Firstly , derive the start index in each reduce operation according to
the output size (`n`), `start_index = ([0, 1, ..., n - 1] * input_size)
// output_size`. For each index `k`, if `k * (input_size % output_size)
< output_size`, then the current and previous stride keeps the same as
`input_size // output_size`. So suppose `(n-1) * (input_size %
output_size) < output_size`, the stride in the whole AdaptiveAvgPool2d
process keeps static, as `input_size // output_size`.
2. Kernel Size
torch/_decomp/decomposations.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of the two
conditions, `input_size % output_size == 0` or `output_size %
(input_size % output_size) == 0`. Here if `input_size % output_size ==
0`, then the kernel size equals `input_size // output_size`, otherwise
`input_size // output_size + 1.`
Torch has all scalars represented as i64 and f64 types which results in
extraneous trunc-extf commands. We can rework this by elliding
widen-narrow cases away.
Current StableHlo lowering strategy works well when `src` tensor's rank
is no bigger than `dst` tensor's. The new patch make it succeed in other
cases. The following is an example.
```
%190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor>
%191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32>
```
- Adds support for lowering depthwise + quantized convolution ops to
linalg::DepthwiseConv2DNhwcHwcQOp
- Changed the variable name for groupSize (which is really C/G) to the
more appropriate numGroups (G).
- Discovered in e2e testing that linalg does not accept (Cin = groups &&
Cout = K*groups for K>1) as a "depthwise" conv, so this also updates the
case-checking to reflect this issue.
Pytorch and ONNX apparently round to nearest, ties go to nearest even,
but we were using `math::round` for the torch-to-linalg conversion of
`quantize_per_tensor`, which rounds away from zero on ties.
This PR adds a conversion in the TorchOnnxToTorch pass for the ONNX
Multinomial operation. It also adds a TorchToLinalg lowering for the
`aten.Multinomial` op and does a light refactor of some repeated code
that generates random floating point numbers in
`TorchToLinalg/Random.cpp`.
This patch adds a few misc pad op related changes:
1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4
Unfortunately, even with all of these changes, the e2e tests for the
ReplicationPad2d still fail the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up not being able to fold back to constants.
The LpNormalization lowering was previously just computing the norm,
which is incorrect. This computes the norm then divides the input tensor
by it's norm.
I've tested this against some simple onnx models locally. I'll look into
adding a test case for this in an external test suite.
Register `aten.fake_quantize_per_channel_affine` and
`aten.fake_quantize_per_tensor_affine.tensor_qparams` ops
---------
Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
Fix the pad tensor rearrangement such that we change the representation
from [x1_begin, x2_begin, ..., x1_end, x2_end,...] to [xn_begin, xn_end,
...., x2_begin, x2_end, x1_begin, x1_end] where x1, x2 .. xn are the
dimensions of the pads tensor argument.
---------
Co-authored-by: zjgarvey <zjgarvey@gmail.com>
Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
* use lhs tensor's element type as compute type when rhs is scalar.
* previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 =
stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>)
-> tensor<2x5xi1>`
* now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT :
(tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>`
Addresses an issue with onnx.Gather lowering to linalg:
<https://github.com/nod-ai/SHARK-Turbine/issues/242>
The builder for tensor.expand_shape, without an explicitly provided
output shape, fails to infer an output shape in the case of multiple
dynamic reassociation dims. I tried adding the output shape explicitly
for tensor.expand_shape, but ran into compilation issues later on (see
<https://github.com/iree-org/iree/issues/17760>).
This PR adds support by lowering this op to tensor.reshape when multiple
dynamic reassociation dims are provided.
Due to the custom operation parser, the print and parser were expecting
two different forms.
One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
This bump triggered an upstream assert. Includes a WAR for #3506.
Also includes several things I needed to do to repro:
* When TORCH_MLIR_TEST_CONCURRENCY=1, test runs will be printed.
* Added TORCH_MLIR_TEST_VERBOSE=1 handling to enable verbose mode
(useful on CI).
---------
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
- Adds limited support for lowering onnx.Loop to primLoopOp
- lower in the pipeline`torch-to-scf` there is a check to see if loop is
for like. A primLoopOp is for like when the input condition is a
`trueBoolConstant`. To adapt the onnx to torch lowering to take
advantage of it, the implementation checks for specific op patterns in
the loodBody region and decides if loop is for like and uses the right
input condition op.
- to adapt the onnxLoopBody to torchLoopBody, we need to adapt the input
block arguments and set the correct output condition variable in the
loop body.
- scanOutput variables are currently not supported.
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensure that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.
We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
For example
```
module {
func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
%arg1: !torch.vtensor<[336,168,3,3],f32>,
%arg2: !torch.vtensor<[336],f32>)
-> !torch.vtensor<[32,336,56,56],f32> {
%false = torch.constant.bool false
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct : () -> !torch.list<int>
%3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2
: !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
!torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
-> !torch.vtensor<[32,336,56,56],f32>
return %3 : !torch.vtensor<[32,336,56,56],f32>
}
}
```
would result in
```
[...]
%padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
%45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
and with this PR all shapes are static.
The `index_put` operation, `input[indices] = values`, allows for the
values to be any shape that is broadcastable to the slice
`input[indices]`. This commit adds broadcasting support to the Linalg
lowering of `IndexPutHackedTwinOp`.
Fixes: #3465
This adds support for a few ops:
- torch.linalg_det
- torch._linalg_det (if the LU and pivot returns are unused)
- onnx.Det
An scf loop is used, since the row reduction algorithm applied here has
some loop-carried dependencies.
The current support being added here is very basic, and only works if no
permutations are required during row reduction, and assumes the matrices
are non-singular.
This adds a torchvision op to torch-mlir and a path from onnx.DeformConv
to torchvision.deform_conv2d.
I'm not implementing the torch->linalg lowering for the torchvision op
yet, but posting this PR to get feedback on some of the choices being
made here and to flesh out the onnx frontend a bit.
This adds an onnx->torch conversion for onnx.RoiAlign into
torchvision.roi_align or torchvision.roi_pool, and adds those two
torchvision ops to torch-mlir.
Add a new op with shape/dtypes and decompose into
`fake_quantize_per_tensor_affine` when the second result is unused.
The xfail_set change is on ONNX because torch cannot export this op to
ONNX.
1. truncates zero-points to i32
2. modifies the default accumulator type for i8 from i64 to i32.
3. now uses the input dtype to infer accumulator dtype.
This implements the Onnx.NegativeLogLikelihoodLoss op using the
signature provided
[here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html)
by replacing it with a `NLLLossForward` op.
Additionally, I included a helper function `get_loss_reduction_enum` to
convert from a string `reduction` parameter to the corresponding
intended integer value since this is an operation that will be reused
for any loss function module. This differs from `get_reduction_enum` in
`TorchUpstream.cpp` which handles the `reduce` parameter from
`scatter_reduce` type operations.