yyp0
bdbc64a205
[TorchToStablehlo] support l1_loss, deg2rad, logit ( #3865 )
2024-11-18 11:25:00 +08:00
yyp0
7058f456b8
[Stablehlo] support aten.isfinite ( #3850 )
2024-11-07 16:52:39 +08:00
Jiawei Wu
b75d0e3f8b
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter ( #3829 )
...
In torch.index_put like ops, `values` is only required to be
broadcastable to `input[indices]`, rather than exact dimension match.
This patch fixes the problem by add additional
stablehlo.dynamic_broadcast_in_dim before creating stablehlo.scatter op.
BTW, this patch also enhance the `getBroadcastResultShape` utility in
hlo namespace.
2024-11-05 19:15:11 +08:00
Xinyu Yang
3dbeda9082
[Stablehlo] fix template typo ( #3842 )
...
I think we should use template parameters. @yyp0 @qingyunqu
2024-11-01 21:10:38 +08:00
yyp0
9ce2a69703
[Torch] support AtenExp2Op ( #3832 )
...
- support AtenExp2Op by decomposing it to aten.pow.scalar
- refine stablehlo pow.scalar pow.Tensor_Scalar pow.Tensor_Tensor
lowering according to https://github.com/llvm/torch-mlir/pull/2983
- Close https://github.com/llvm/torch-mlir/pull/2983
2024-10-31 19:14:05 +08:00
yyp0
d0041dc310
[stablehlo] support aten.view.dtype lowering ( #3778 )
2024-10-10 15:50:17 +08:00
Rob Suderman
2374b9e02d
Bump to llvm/llvm-project@e813750354 ( #3765 )
...
Includes stablehlo bump
2024-10-04 12:08:35 -07:00
Yuanqiang Liu
5f74de5ba0
[Stablehlo] support aten.all.dim ( #3746 )
2024-09-30 15:59:27 +08:00
yyp0
335cf5f6d0
[stablehlo] support aten_adaptive_max_pool1d lowering ( #3728 )
2024-09-26 11:42:38 +08:00
Yuanqiang Liu
7b94ced39a
[Stablehlo] fix aten compare ops' promote rules ( #3709 )
...
previous PR(https://github.com/llvm/torch-mlir/pull/3702 )
2024-09-13 18:48:41 +08:00
yyp0
43e3118eb9
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp ( #3592 )
2024-08-15 20:06:29 +08:00
Jiawei Wu
edc87fc577
[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops ( #3322 )
...
For now, at most one dynamic dim of index tensors in
aten.index/aten.index_put-like op is supported.
2024-08-01 10:41:09 +08:00
Jiawei Wu
7b2902f6e2
[stablehlo]: fix aten.index_put_hacked_twin lowering to StableHlo ( #3572 )
...
Current StableHlo lowering strategy works well when `src` tensor's rank
is no bigger than `dst` tensor's. The new patch make it succeed in other
cases. The following is an example.
```
%190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor>
%191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32>
```
2024-07-31 22:33:57 +08:00
Yuanqiang Liu
5bee9aac63
[Stablehlo] simplify promoteType ( #3525 )
...
only provide `outElementType` when promoteType
2024-07-10 10:52:19 +08:00
Yuanqiang Liu
3225f20ab1
[Stablehlo] use index type as dim size, avoid to generate index_cast ( #3526 )
...
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir )
2024-07-07 18:03:03 +08:00
Yuanqiang Liu
f1e3701caf
[Stablehlo] fix compareOp with scalar's lowering ( #3518 )
...
* use lhs tensor's element type as compute type when rhs is scalar.
* previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 =
stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>)
-> tensor<2x5xi1>`
* now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT :
(tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>`
2024-07-02 15:31:06 +08:00
Yuanqiang Liu
0e71a192d8
[Torch] support decomposition of aten.aminmax ( #3513 )
...
* unify decompisition of `aten.amax` and `aten.amin`
* support `aten.amax` with `dim=()`
2024-06-29 21:44:05 +08:00
Yuanqiang Liu
f9fc741eef
[Stablehlo] support aten.any.dim, aten.min.dim ( #3500 )
...
* refactor `TorchToStablehlo/Reduction.cpp`
* add `ConvertAtenReduceWithIndicesOp` patterns
2024-06-29 16:53:33 +08:00
Xinyu Yang
c7d52f63b4
[stablehlo] add aten::_int_mm lowering ( #3474 )
...
as title
2024-06-20 16:10:31 +08:00
Xinyu Yang
431d98b405
[Stablehlo] Add lowering of GridSampler Op ( #3084 )
...
Inspired by PyTorch decompositions.py.
See
ec58f1f74e/torch/_decomp/decompositions.py (L3923-L4086)
Only support paddingMode=0 or 1 and interpolationMode=0 or 1
2024-06-07 16:06:07 +08:00
Yuanqiang Liu
50f7103098
[Stablehlo] support uint8 ( #3367 )
...
Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184 .
The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because it's different result type between linalg backend and stablehlo
backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
%c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8>
%0 = tensor.empty() : tensor<3xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
^bb0(%in: i8, %out: f32):
%2 = arith.uitofp %in : i8 to f32
linalg.yield %2 : f32
} -> tensor<3xf32>
return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
%c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8>
%0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
2024-06-04 09:04:59 +08:00
Xinyu Yang
23b53050de
[Torch]Support conv_transpose1d and conv_transpose3d ( #3286 )
...
1. Support conv_transpose1d and conv_transpose3d
2. Fix bugs of convertTransposedConv func in
lib/Conversion/TorchToStablehlo/Linear.cpp
2024-06-03 15:11:12 +08:00
Rob Suderman
afca88a058
[NFC] Change to *cast instead of .*cast variants ( #3405 )
...
Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
2024-05-30 23:45:13 -07:00
penguin_wwy
1f544c37d0
[NFC] Remove unused header files ( #3386 )
2024-05-30 14:30:36 +08:00
Yuanqiang Liu
28aeb047c1
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic ( #3389 )
2024-05-26 12:34:56 +08:00
Yuanqiang Liu
5bb1a65ec9
[Stablehlo] refactor reduction lowering and support aten.amin ( #3383 )
...
* implement detailed lowering template pattern
`ConvertAtenReduceAllDimsOp` and `ConvertAtenReduceKeepDimOp`
* support `aten.amin`'s lowering.
2024-05-23 20:40:20 +08:00
Yuanqiang Liu
f4bfe3f948
Bump llvm and stablehlo ( #3377 )
...
* bump llvm to 1e5f29af81a5f6fda308074f6345b9fba4faa71c
* bump stablehlo to c44d9af8d4879adccf1054cb61a53377ae5898cb
2024-05-22 23:28:45 +08:00
Wu Yuan
cc28d566ff
[Stablehlo] Support AtenTrilOp ( #3359 )
...
1. lower aten.tril to stablehlo composed by iota, select and so forth
2. add related e2e test cases
2024-05-20 15:49:24 +08:00
Xinyu Yang
28193fd985
[Stablehlo]index type use i64 ( #3354 )
2024-05-16 15:33:23 +08:00
Yuanqiang Liu
5928f68e60
[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo ( #3348 )
...
* not to decompose `aten.amax` on `stablehlo` backend. Because it could
be lowering to `stablehlo.reduce` directly.
* lowering `aten.max.dim` to `stablehlo.reduce apply max` when
`AtenMaxDimOp.getIndices()` doesn't have users. It's more simple.
2024-05-16 00:05:19 +08:00
Yuanqiang Liu
0b7cbf5e60
[Stablehlo] fix aten.randn's lowering with f32 element type ( #3329 )
2024-05-11 17:40:04 +08:00
Yuanqiang Liu
5f7cb9e253
[Stablehlo] lowering aten.randn & aten.normal_functional to mhlo.rng … ( #3328 )
...
…NORMAL
* split lowering of uniform, randn, normal from Basic.cpp into Rng.cpp
2024-05-11 15:33:37 +08:00
penguin_wwy
e0a87e543e
[NFC] Standardize the std::is_same competime expression ( #3321 )
2024-05-10 17:07:37 +08:00
penguin_wwy
afe87d62b4
[Linalg] [Stablehlo] Promote type for compare scalar op ( #3306 )
2024-05-10 02:20:06 +08:00
Yuanqiang Liu
5213557b87
[Stablehlo] fix lowering gelu(x, tanh) ( #3307 )
...
* lowering gelu("none") to erf
* lowering gelu("tanh") to tanh
2024-05-09 11:39:13 +08:00
Xinyu Yang
f32ada993d
[Stablehlo] Improve the lowering of pool op in stablehlo ( #3259 )
...
1. Handle case stride == None
2. add avgpool3d maxpool1d maxpool3d lowering
2024-05-01 00:06:13 +08:00
Xinyu Yang
0a5ff68d9d
[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo ( #3230 )
2024-04-29 17:40:30 +08:00
Stella Laurenzo
5d4b803914
[NFC reformat] Run pre-commit on all files and format misc.
...
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.
Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
penguin_wwy
6679728c56
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa ( #3243 )
...
Like #3130 , gradually replace the deprecated code
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
2024-04-27 14:00:56 -07:00
Xinyu Yang
ac85338491
[Stablehlo] Support AtenPowScalarOp, AtenTanOp, AtenAsinhOp, AtenAcoshOp, AtenAtanhOp, Atan2Op ( #3233 )
2024-04-26 15:47:44 +08:00
penguin_wwy
122eb69a98
[stablehlo] add aten left/right shift op conversion support ( #3234 )
2024-04-26 09:20:49 +08:00
Xinyu Yang
7030eacb76
[stablehlo] Support aten.any and aten.all lowering ( #3217 )
2024-04-25 11:15:52 +08:00
Xinyu Yang
e18bf42d0e
[stablehlo] Support ConstantPadNdOp in stablehlo ( #3211 )
...
as title
2024-04-24 14:15:11 +08:00
Xinyu Yang
42b9eccdb3
[Stablehlo] Fix AtenSumDimIntListOp when dim==None ( #3216 )
...
as titile
2024-04-24 11:25:46 +08:00
Xinyu Yang
4da3d714cc
[Torch] Support AtenProdOp on linalg and stablehlo ( #3215 )
2024-04-24 11:14:04 +08:00
Yuanqiang Liu
db3842f2e8
[Stablehlo] support lowering sinh & cosh to stablehlo ( #3213 )
2024-04-23 19:54:58 +08:00
Xinyu Yang
c1967b607f
[Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo ( #3208 )
2024-04-23 19:06:55 +08:00
Yuanqiang Liu
1f8123b5f0
[Stablehlo] support unary ops which promote to floating point ( #3209 )
...
* promote input to output element-type when lowering to stablehlo, so
that it could satisfy stablehlo's type constraints.
* split promote-to-fp unary ops from fp-only unary ops.
2024-04-23 17:57:12 +08:00
Yuanqiang Liu
797e4cd395
[Stablehlo] lowering asin, acos, atan ( #3207 )
...
* lowering asin, acos and atan to chlo ops.
2024-04-23 16:24:53 +08:00
penguin_wwy
a60e84e5ee
[stablehlo] add aten.expm1 op conversion support ( #3199 )
2024-04-21 19:20:49 -07:00