[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)

This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <https://github.com/iree-org/iree/issues/18677>
2. <https://github.com/iree-org/iree/issues/18631>
pull/3766/head
zjgarvey 2024-10-04 11:27:00 -05:00 committed by GitHub
parent f08bfc4ff8
commit 6e8c7bed4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 7 deletions

View File

@ -1627,6 +1627,25 @@ public:
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
bool isScalarOp = resultType.getShape().size() == 0;
if (isScalarOp) {
// for elementwise ops that are actually rank0 scalar computations,
// perform the payload outside a linalg generic op.
SmallVector<Value> payloadArgs;
for (auto t : tensorOperands) {
payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
}
Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
if (!scalarResult)
return rewriter.notifyMatchFailure(
op, "Failed to create payload for scalar elementwise op");
Value rank0Result =
createInitTensor(rewriter, loc, ValueRange{},
resultType.getElementType(), scalarResult);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
return success();
}
bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),

View File

@ -4,13 +4,11 @@
// CHECK-LABEL: func.func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
// CHECK: linalg.yield %[[TANH]] : f32
// CHECK: } -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: }