Commit Graph

14 Commits (18139994e807d262f52a13b2c8e1b3edfa45ffa0)

Author SHA1 Message Date
Jiawei Wu edc87fc577
[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops (#3322)
For now, at most one dynamic dim of index tensors in
aten.index/aten.index_put-like op is supported.
2024-08-01 10:41:09 +08:00
Yuanqiang Liu 5bee9aac63
[Stablehlo] simplify promoteType (#3525)
only provide `outElementType` when promoteType
2024-07-10 10:52:19 +08:00
Yuanqiang Liu 3225f20ab1
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
penguin_wwy 1f544c37d0
[NFC] Remove unused header files (#3386) 2024-05-30 14:30:36 +08:00
Yuanqiang Liu 28aeb047c1
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389) 2024-05-26 12:34:56 +08:00
Xinyu Yang 0a5ff68d9d
[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230) 2024-04-29 17:40:30 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
penguin_wwy 6679728c56
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code

https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
2024-04-27 14:00:56 -07:00
penguin_wwy b98f7f75dc
[stablehlo] Reduce unnecessary template specialization code (#3047) 2024-04-01 14:18:49 -07:00
Sambhav Jain 8a17c98b74
Bump stablehlo to openxla/stablehlo@fd52182f76 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```

I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.

It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test 

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference                                                                               
  %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>             
       ^                                                                                                                                                                                            
LLVM ERROR: Failed to infer result type(s).               
```

Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-01-31 14:21:17 -08:00
Stella Laurenzo 8252656b6d
Advance llvm-project and stablehlo. (#2619)
llvm-project: bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df
stablehlo: ab709fe48de88c67717abfbd7ef17425eb95ddaf

These commits were chosen in order to account for an MLIR API break from
3dbac2c007
which required a patch to stablehlo. We integrate a bit beyond that
commit to deal with some revert/reapply cycles in the intervening range
which were discovered in another downstream.

Further, it requires adaptation to the stablehlo API breaks introduced
from https://github.com/openxla/stablehlo/pull/1872 which are along for
the ride.

Since some stablehlo builders were changed to directly take int64_t
array refs, also traced that up some call stacks to eliminate some
signed/unsigned mismatches that result.

Also adds a few TOSA tests to the passing set that seem to work now.
2023-12-07 23:13:42 -08:00
Yuanqiang Liu 0548e2ef3b
[Stablehlo] fix promoteType() when input doesn't have DefiningOp (#2262) 2023-06-26 00:04:17 +08:00
Ziheng Jiang 72bb902640
[STABLEHLO] Move utils.h to include/ (#1974) 2023-03-27 21:16:21 -07:00
Ashay Rane 711646d095
mhlo: migrate conversion to stablehlo (#1840)
This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
2023-02-02 07:29:47 -06:00