Commit Graph

737 Commits (cc28d566ff02904bde4855f3e2c8a124ecb6f4d6)

Author SHA1 Message Date
Yuanqiang Liu 8814d0ae64
[Torch] emit aten.dot and canonicalize it to aten.matmul (#3361)
* canonicalize `aten.dot` to `aten.matmul`
2024-05-18 22:45:14 +08:00
Xinyu Yang 7faba75696
[Torch] Decompose AtenMaskedScatterOp (#3353)
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
2024-05-16 15:27:25 +08:00
Xinyu Yang a9edefb3cf
[Torch] Fix AtenSliceTensorOp::fold (#3345) 2024-05-16 11:42:43 +08:00
Peiming Liu ccb772cd0f
[sparse] propagate sparsity properly when decompose torch operations. (#3318) 2024-05-15 10:09:27 -07:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
Xinyu Yang 6b95dd461d
[Torch] Fix PrimNumToTensorScalarOp::fold (#3339)
In constant folding progress, a new constant op will be created
according to the origin op's result type.

See the code in TorchDialect.cpp.

```cpp
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto integerType = dyn_cast<Torch::IntType>(type))
    return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));

  if (auto floatType = dyn_cast<Torch::FloatType>(type))
    return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));

  if (auto numberType = dyn_cast<Torch::NumberType>(type)) {
    if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
      return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
    } else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
      return builder.create<Torch::ConstantNumberOp>(loc, intValue);
    }
  }

  if (isa<Torch::BoolType>(type)) {
    return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
  }

  if (isa<Torch::NoneType>(type))
    return builder.create<ConstantNoneOp>(loc);

  if (auto stringAttr = dyn_cast<StringAttr>(value))
    return builder.create<ConstantStrOp>(loc, stringAttr);

  if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) {
    // Only !torch.vtensor can be constant folded. !torch.tensor has
    // non-trivial aliasing semantics which prevent deduplicating it.
    assert(isa<ValueTensorType>(type) && "should be a vtensor type!");
    return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
  }

  return nullptr;
}
```
So when the op has a tensor result type, it must be "ValueTensorType"
due to the **assert** statement. However, many fold methods in
TorchOps.cpp only have a judgment of "BaseTensorType".
2024-05-15 20:54:19 +08:00
zjgarvey 911e723581
Expands Q Commuting Ops (#3332)
After running the model tests in SHARK-TestSuite, I noticed a few model
failures due to half-fusion.

Notably, RDN_pytorch_vaiq_int8 had a depth=5 convolution chain with
multiple AtenViewOp's.
2024-05-13 11:01:53 -07:00
zjgarvey 75d1d72059
Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
2024-05-12 20:49:59 -07:00
NeverRaR 1d4859699b
MaxPool1d lowering to linalg (#3295)
Co-authored-by: root <root@i32b01216.sqa.eu95>
2024-05-10 22:05:26 +05:30
penguin_wwy 64b59c7fc3
[FxImporter] Eliminate the dependency on the refinement pass (#3309) 2024-05-10 02:44:36 +08:00
aldesilv ec6d7aa5d2
OnnxToTorch lowering resize op (#3013)
https://github.com/nod-ai/SHARK-Turbine/issues/358
adds a lowering from onnx to linalg for bilinear and nearest resize with
support for using scales or sizes to get resize shape. uses coordinate
transform half pixel for bilinear mode and asymmetrical for nearest
mode. See
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Added
two passes -- one for bilinear and the other for nearest.
2024-05-08 21:35:03 +00:00
Jiawei Wu 346a536c9f
[Torch Dialect] decompose all index_put-like op to aten.index_put.hacked_twin for stricter semantics (#3071)
This PR decomposes all index_put-like op to aten.index_put.hacked_twin for stricter semantics, i.e., no None index in indices argument.
2024-05-08 22:44:57 +08:00
Xinyu Yang abef114c0c
[torch] emit aten.Softshrink and aten.Hardshrink (#3248)
as title
2024-05-08 15:20:45 +08:00
Vivek Khandelwal e60160d793
Revert "Decompose AtenNonzeroOp" (#3289)
Reverts llvm/torch-mlir#3281
2024-05-06 09:52:04 -07:00
Xida Ren (Cedar) 1af00e6040
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
2024-05-05 21:59:25 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Xida Ren (Cedar) 315dc6c3e3
[torch] `aten.eye` should use dynamic dims when no static dims are available (#3202)
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-30 17:41:03 +00:00
Xinyu Yang f32ada993d
[Stablehlo] Improve the lowering of pool op in stablehlo (#3259)
1. Handle case stride == None
2. add avgpool3d maxpool1d  maxpool3d lowering
2024-05-01 00:06:13 +08:00
Vivek Khandelwal b1e2241479
[ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op (#3221)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-29 04:00:01 +00:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00
Xinyu Yang 5684dc0441
[Torch] emit aten.celu and decompose it (#3247)
CELU(x)=max(0,x)+min(0,α∗(exp(x/α)−1))
2024-04-28 17:23:40 +08:00
Yuanqiang Liu 46c0f3cad0
[Torch] emit aten.log_sigmoid and decompose it to log(sigmoid) (#3246) 2024-04-28 11:47:43 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
penguin_wwy 6679728c56
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code

https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
2024-04-27 14:00:56 -07:00
Yuanqiang Liu f173a06fa7
[Torch] emit aten.ne.str and add folder (#3242) 2024-04-28 00:58:50 +08:00
Yuanqiang Liu 634a796933
[Torch] fold aten.log (#3223) 2024-04-26 10:10:02 +08:00
Yuanqiang Liu b0ba3def93
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231) 2024-04-26 02:36:13 +08:00
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Xinyu Yang 4da3d714cc
[Torch] Support AtenProdOp on linalg and stablehlo (#3215) 2024-04-24 11:14:04 +08:00
zjgarvey a8ba865fca
[torch] Adds Quantization Support for `aten.relu` (#3177)
A choice was made to quantize the return type of Relu with a scale and
zero point copied from the input's quantization scheme. With this
choice, the torch-to-linalg conversion of quantized Relu essentially
computes max(input, zeroPoint) in the elementwise payload.
2024-04-23 11:01:36 -07:00
Yuanqiang Liu db3842f2e8
[Stablehlo] support lowering sinh & cosh to stablehlo (#3213) 2024-04-23 19:54:58 +08:00
penguin_wwy e5bdd71baf
[Torch] Emit and decompose prims.iota op (#3132) 2024-04-21 19:45:01 -07:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Xinyu Yang d4313eed4a
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```

I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI  **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
Xinyu Yang d2ba956e69
[Torch] Support Aten_CastLongOp. (#3160)
By canonicalize Aten_CastLongOp into AtenToDtypeOp
2024-04-17 21:58:32 +08:00
zjgarvey 5e564b5864
Adds Some Quantization Support for AtenMatmulOp (#3147)
1. onnx.MatMulInteger now converts to aten.matmul instead of aten.mm
2. aten.matmul, for ranks >=2, now allows quantized inputs and will
lower to linalg::quantized_matmul or linalg::quantized_batch_matmul.
3. added AtenMatmulOp to the FuseQuantizeOps rewrite patters
QuantizeOperands, QuantizeTransposedOperands, and QuantizeAccumulator
4. added several tests, including some to test AtenMmOp with varying
quantization signed-ness.
5. a quantized matmul mat-vec test is added to verify the failure to
lower to linalg; cleaned of out-of-date code related to common
torch-mlir lowering xfails.
6. in debugging a real model with quantized matmuls, I found a bug on
the scalarize-shapes pass which resulted from the aten.full op folder
returning an incompatible result type. This is fixed by the small change
here to
[lib/Dialect/Torch/IR/TorchOps.cpp](https://github.com/llvm/torch-mlir/compare/main...zjgarvey:torch-mlir:MatMulIntegerFix?expand=1#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4f).
2024-04-15 16:06:47 -07:00
IanWood1 5708ee7ec9
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
2024-04-15 13:45:10 -07:00
zjgarvey 197ef4224b
Avoid Type Mismatch in Slice Folder (#3154)
Fixes issue #3153
2024-04-12 11:43:45 -07:00
penguin_wwy d4a30b7e67
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)
We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
2024-04-11 06:47:35 -07:00
Xinyu Yang 308c45e61a
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
    def forward(self, x):
        if len(x.shape) == 5:
            b0, t, c0, h0, w0 = x.shape
            b, c, h, w = torch.mul(b0, t), c0, h0, w0
        else:
            b1, c1, h1, w1 = x.shape
            b, c, h, w = b1, c1, h1, w1
        res = torch.reshape(x, [b, c, h, w])
        return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
Xinyu Yang 6524838bcb
[Torch] Add general AdaptiveAvgPool2dOp decompose support (#3111)
Previously, it could only handle the situations where outputsize == (1,
1) or outputsize == (input_H, input_W). Now it supports all situations
where input_H % output_H== 0 && input_W % output_W == 0
2024-04-11 17:02:59 +08:00
Xinyu Yang 5eb0cf9104
[Torch] Add decompose of AtenToPrimDeviceOp (#3131)
As device information isn't relevant to torch-mlir
2024-04-10 22:26:48 +08:00
Xinyu Yang 42a16fa912
[Torch] Support Aten_CastFloatOp. (#3115)
By canonicalize Aten_CastFloatOp into AtenToDtypeOp
2024-04-09 11:06:53 +08:00
Xinyu Yang 84c24e5771
[Torch] Support Aten__And__ScalarOp (#3114) 2024-04-08 20:24:17 +08:00
Yuanqiang Liu 2c56ef9252
[Torch Dialect] canonicalize aten.sign to aten.sgn (#3112)
* `aten.sign` is a sub-set of `aten.sgn` (`aten.sgn` support complex
type).
2024-04-08 20:05:42 +08:00
Vivek Khandelwal 7e778e2179
build: manually update PyTorch version (#3094)
Set PyTorch and TorchVision version to nightly release 2024-04-01.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-03 10:48:37 +05:30
Rob Suderman f97cd4893f
[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors to represent the set of dimensions.
As reshapes take a list of scalars this can result in a single dynamic
dimension blocking the adjacent static dimensions.

This pass attempts to de-couple tensor computations related to shapes
and propagate values to better support lowering scalar tensor
computations.
2024-04-02 16:19:57 -07:00
zjgarvey 40e762ca42
Adds result types to a prelu decomp (#3098)
This adds explicit result types instead of relying on shape/dtype
computations.

Solves a regression issue with IREE: #3092
2024-04-02 11:41:56 -07:00
zjgarvey 532d297c46
[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)
See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to onnx.Cast op
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with
shape [1].
3. Adds support for unsigned types in an AtenItemOp folder
4. Adds a simpler quantized model for easier debugging
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm]
patterns to [transpose -> quant -> mm].
6. Moved some xfails that are still not passing, but for different
reasons than onnx.cast failures.
2024-04-01 16:21:05 -07:00
Xinyu Yang da88efad89
[Torch] Fix bug of DecomposeAtenSelectIntOp (#3087)
Fix bug of DecomposeAtenSelectIntOp. Because it may use resultTy when
resultTy has not been inferred.

```
    auto resultTy = op.getType().cast<BaseTensorType>();
    if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
      rewriter.replaceOp(op, slice);
      return success();
    }

```

So I add restriction.
2024-04-01 21:25:02 +08:00