mirror of https://github.com/llvm/torch-mlir
3225f20ab1
For example, for the `Matmul3D` module the original IR is:

```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```

After switching the shape tensor to IndexType, the IR is:

```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType for shape tensors:

* It simplifies the IR: no `arith.index_cast` ops need to be generated.
* It gives the backend compiler the chance to choose the index width of shape tensors.
* It makes it possible for the StableHLO backend to serialize dynamic-shape IR with [shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir).
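To make the first benefit concrete for an arbitrary rank, here is a toy Python sketch that only generates MLIR text in the two styles shown above. `emit_shape_tensor` is a hypothetical helper for illustration, not torch-mlir's lowering code; it assumes a ranked `f32` operand with every dimension dynamic. For a rank-N operand, the `i64` style needs N extra `arith.index_cast` ops that the IndexType style omits.

```python
# Toy model of the shape-tensor ops emitted before stablehlo.dynamic_broadcast_in_dim.
# Hypothetical helper: it produces MLIR *text* only and is not torch-mlir's C++ lowering.


def emit_shape_tensor(operand: str, rank: int, use_index_type: bool) -> list[str]:
    """Return MLIR lines that build a rank-`rank` shape tensor for `operand`.

    Assumes `operand` has type tensor<?x...x?xf32> with all dimensions dynamic.
    """
    if rank < 1:
        raise ValueError("rank must be at least 1")
    ty = "tensor<" + "x".join("?" * rank) + "xf32>"  # e.g. tensor<?x?x?xf32>

    # One index constant per dimension, as in the listings above.
    lines = [f"%c{i} = arith.constant {i} : index" for i in range(rank)]
    elems = []
    for i in range(rank):
        dim = "%dim" if i == 0 else f"%dim_{i - 1}"  # %dim, %dim_0, %dim_1, ...
        lines.append(f"{dim} = tensor.dim {operand}, %c{i} : {ty}")
        if use_index_type:
            # tensor.dim already yields `index`, so it can feed the shape tensor directly.
            elems.append(dim)
        else:
            # The i64 style needs an explicit cast for every dimension.
            cast = f"%{i}"
            lines.append(f"{cast} = arith.index_cast {dim} : index to i64")
            elems.append(cast)

    elem_ty = "index" if use_index_type else "i64"
    lines.append(
        f"%from_elements = tensor.from_elements {', '.join(elems)} : tensor<{rank}x{elem_ty}>"
    )
    return lines


if __name__ == "__main__":
    for line in emit_shape_tensor("%arg1", 3, use_index_type=True):
        print(line)
```

With `rank=3` the sketch reproduces the shape-building lines of the two `Matmul3D` listings: 10 lines in the `i64` style versus 7 in the IndexType style.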
* basic.mlir
* elementwise.mlir
* gather.mlir
* linear.mlir
* lit.local.cfg
* pooling.mlir
* rng.mlir
* scatter.mlir
* view_like.mlir