// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
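// Verifies lowering of torch.aten.scatter.src on dynamically shaped si64
// tensors to stablehlo.scatter.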
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_1:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_2:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
// CHECK-DAG: %[[VAR_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK-DAG: %[[VAR_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK-DAG: %[[VAR_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor<?x?xi64>
// CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_1]] : tensor<?x?xi64>
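// Shape tensors below are built directly with index element type, so no
// arith.index_cast ops are emitted (see #3526).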
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xindex>
// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xindex>
// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]] : tensor<2xindex>
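// The src values are sliced to the index tensor's shape before scattering.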
// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor<?x?xi64>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xi64>
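// Scatter coordinates pair each value of the index tensor (the dim-0 target)
// with its own column position produced by dynamic_iota (dim 1).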
// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]], %[[CONSTANT_1]] : tensor<3xindex>
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %[[VAR_1]], %[[FE_3]] : (tensor<?x?xi64>, tensor<3xindex>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xindex>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
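// The update region returns the update value unchanged, matching the
// overwrite semantics of scatter.src.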
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
// CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
%int0 = torch.constant.int 0
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
return %0 : !torch.vtensor<[?,?],si64>
}