From 6e8c7bed4b12117764274e79bc60a93443d5bdd5 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Fri, 4 Oct 2024 11:27:00 -0500
Subject: [PATCH] [TorchToLinalg] perform rank0 elementwise computations
 outside linalg generic ops (#3762)

This is motivated by the fact that shapes are stored as tensors in ONNX, and
IREE tries to perform tensor arithmetic on the device. This causes unnecessary
dispatches, and makes it harder for the compiler to reason about shapes. Here
is a small snippet of torch-IR that is typically seen coming from ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1.
2.
---
 .../TorchToLinalg/Uncategorized.cpp           | 19 +++++++++++++++++++
 .../Conversion/TorchToLinalg/elementwise.mlir | 12 +++++-------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 0f6f92bd7..0532b4b19 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1627,6 +1627,25 @@ public:
         operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
+    bool isScalarOp = resultType.getShape().size() == 0;
+    if (isScalarOp) {
+      // for elementwise ops that are actually rank0 scalar computations,
+      // perform the payload outside a linalg generic op.
+      SmallVector<Value> payloadArgs;
+      for (auto t : tensorOperands) {
+        payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
+      }
+      Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
+          rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
+      if (!scalarResult)
+        return rewriter.notifyMatchFailure(
+            op, "Failed to create payload for scalar elementwise op");
+      Value rank0Result =
+          createInitTensor(rewriter, loc, ValueRange{},
+                           resultType.getElementType(), scalarResult);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
+      return success();
+    }
     bool hadErrorCreatingPayload = false;
     Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
         rewriter, loc, tensorOperands, resultType.getElementType(),
diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index aa2be74f5..ecf4caa58 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -4,13 +4,11 @@
 // CHECK-LABEL: func.func @elementwise$unary(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
 // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
-// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
-// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
-// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
-// CHECK: linalg.yield %[[TANH]] : f32
-// CHECK: } -> tensor<f32>
-// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
+// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
+// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
 // CHECK: }