torch-mlir/test
zjgarvey 6e8c7bed4b
[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)
This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <https://github.com/iree-org/iree/issues/18677>
2. <https://github.com/iree-org/iree/issues/18631>
2024-10-04 11:27:00 -05:00
..
CAPI [NFC reformat] Run pre-commit on all files and format misc. 2024-04-27 14:08:09 -07:00
Conversion [TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762) 2024-10-04 11:27:00 -05:00
Dialect [torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751) 2024-10-02 05:55:54 -07:00
RefBackend Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
python [FxImporter] Fix sympy_int_to_int utility (#3657) 2024-08-26 09:31:17 -07:00
CMakeLists.txt [NFC reformat] Run pre-commit on all files and format misc. 2024-04-27 14:08:09 -07:00
lit.cfg.py [NFC reformat] Applies pre-commit formatting to Python files. (#3244) 2024-04-27 14:16:31 -07:00
lit.site.cfg.py.in Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00