Commit Graph

25 Commits (main)

Author SHA1 Message Date
Jiawei Wu b75d0e3f8b
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829)
In torch.index_put like ops, `values` is only required to be
broadcastable to `input[indices]`, rather than exact dimension match.
This patch fixes the problem by add additional
stablehlo.dynamic_broadcast_in_dim before creating stablehlo.scatter op.
BTW, this patch also enhance the `getBroadcastResultShape` utility in
hlo namespace.
2024-11-05 19:15:11 +08:00
yyp0 43e3118eb9
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592) 2024-08-15 20:06:29 +08:00
Jiawei Wu edc87fc577
[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops (#3322)
For now, at most one dynamic dim of index tensors in
aten.index/aten.index_put-like op is supported.
2024-08-01 10:41:09 +08:00
Jiawei Wu 7b2902f6e2
[stablehlo]: fix aten.index_put_hacked_twin lowering to StableHlo (#3572)
Current StableHlo lowering strategy works well when `src` tensor's rank
is no bigger than `dst` tensor's. The new patch make it succeed in other
cases. The following is an example.
```
%190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor>
%191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32>
```
2024-07-31 22:33:57 +08:00
Yuanqiang Liu 3225f20ab1
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
Xinyu Yang 431d98b405
[Stablehlo] Add lowering of GridSampler Op (#3084)
Inspired by PyTorch decompositions.py.
See
ec58f1f74e/torch/_decomp/decompositions.py (L3923-L4086)
Only support paddingMode=0 or 1 and interpolationMode=0 or 1
2024-06-07 16:06:07 +08:00
Yuanqiang Liu 50f7103098
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184.

The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because it's different result type between linalg backend and stablehlo
backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
2024-06-04 09:04:59 +08:00
Rob Suderman afca88a058
[NFC] Change to *cast instead of .*cast variants (#3405)
Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
2024-05-30 23:45:13 -07:00
penguin_wwy 1f544c37d0
[NFC] Remove unused header files (#3386) 2024-05-30 14:30:36 +08:00
Yuanqiang Liu 28aeb047c1
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389) 2024-05-26 12:34:56 +08:00
Yuanqiang Liu f4bfe3f948
Bump llvm and stablehlo (#3377)
* bump llvm to 1e5f29af81a5f6fda308074f6345b9fba4faa71c
* bump stablehlo to c44d9af8d4879adccf1054cb61a53377ae5898cb
2024-05-22 23:28:45 +08:00
Xinyu Yang 28193fd985
[Stablehlo]index type use i64 (#3354) 2024-05-16 15:33:23 +08:00
penguin_wwy 6679728c56
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code

https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
2024-04-27 14:00:56 -07:00
penguin_wwy d4a30b7e67
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)
We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
2024-04-11 06:47:35 -07:00
Jiawei Wu 76080936d4
[stablehlo] add aten.index_put and aten.scatter_add op conversion support (#3086) 2024-04-01 19:39:49 +08:00
Sambhav Jain 8a17c98b74
Bump stablehlo to openxla/stablehlo@fd52182f76 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```

I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.

It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test 

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference                                                                               
  %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>             
       ^                                                                                                                                                                                            
LLVM ERROR: Failed to infer result type(s).               
```

Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-01-31 14:21:17 -08:00
Quinn Dawkins 494089d53d
Clang format refresh (#2812)
After noticing a number of commits with unrelated formatting changes,
I think something was changed with clang-format at one point and we're
seeing a number of unrelated changes. Doing a refresh can help avoid
this.

The changes made here came from
```
find lib -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find include -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find projects -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
```
2024-01-29 12:59:33 -05:00
Vivek Khandelwal 3841fe3035 [MLIR][TORCH] Add StableHLO lowering for embedding_bag.padding_idx op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-09-05 21:32:23 +05:30
Ramiro Leal-Cavazos 41bafe13cc
[build] Update llvm tag to a3f2751f (#2397)
This commit updates the `llvm-project` and `mlir-hlo` submodules to
commits:

llvm-project: a3f2751f782f3cdc6ba4790488ec20163a40ac37
mlir-hlo: 97c7e4b4506c3a2441c923e592833f45da439009

Changes made:

- Rename `getSuccessorEntryOperands` with `getEntrySuccessorOperands`
and remove `operands` from
`getSuccessorRegions` (https://reviews.llvm.org/D157506)
- Make `TypeConverter` a `const` (https://reviews.llvm.org/D157601)
2023-08-15 09:53:28 -07:00
Jiawei Wu 60bad54f27
[Torch Dialect] replace none-index in aten.Index.Tensor's param by manually generating it (#2344)
* [Torch Dialect] replace none-index in aten.Index.Tensor's  param by manually generating it
Co-authored-by: Jiawei Wu <wujiawei.aml@bytedance.com>
Co-authored-by: Jianzhe Xiao <jianzhe.xiao@bytedance.com>

* minor typo fix

* add new failed e2e tests for ltc

* fix typo

* Address comments

* Add more e2e tests

* add failed e2e tests for LTC

* address comments

* remove decomposition for AtenIndexTensorHackedTwinOp
2023-08-15 19:36:08 +08:00
Jiawei Wu 026e8db2e4
[Stablehlo] add converter for aten.scatter.src op (#2295) 2023-07-24 10:14:45 +08:00
Zhekun Zhang eb8f56aeb7
[Stablehlo] Add `AtenIndexTensor` StableHlo support (#2107)
* Add AtenIndexTensor StableHlo support

* clean up

* Empty commit, trigger test

* try to debug hanging test

* fix segfulat

* fix bad include

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
2023-05-24 11:13:57 -07:00
Abhishek Varma 318fe13468 [MLIR][TORCH] Patch up Ops and their lowerings to deal with +ve `dim`
-- In Python we have the concept of negative dimension indexing.
-- We would want to normalize such dimensions to be +ve and within the
   expected range instead.
-- This commit takes care of a few remaining set of Ops and their
   lowerings by applying `toPositiveDim` and `isValidDim` to the
   extracted integer `dim` value.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-14 13:12:56 +05:30
Ziheng Jiang 72bb902640
[STABLEHLO] Move utils.h to include/ (#1974) 2023-03-27 21:16:21 -07:00
Zhekun Zhang 5758a0bfbb
[StableHLO] Support for slice_scatter (#1960)
Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
2023-03-22 13:41:04 -07:00