From 46f2cb50dca5e789d1114b127d9a4312fbb8e3d9 Mon Sep 17 00:00:00 2001 From: John Wu Date: Thu, 21 Dec 2023 07:29:22 -0800 Subject: [PATCH] [onnx] Lower onnx.HardSigmoid to torch (#2682) The expression for HardSigmoid in Onnx (https://onnx.ai/onnx/operators/onnx__HardSigmoid.html): max(0, min(1, alpha * x + beta)) is inherently different from HardSigmoid in Torch (https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html) which is: if x < -3 -> 0 elif x > 3 -> 1 else x/6 + 1/2 That being said, it was just better to compute out the entire expression when translating the Onnx expression to Torch mlir, which is done in this PR. Some of the logic is shared from the files in `DecomposeComplexOps`. Therefore, refactored some shared logic between `DecomposeComplexOps` and `DefaultDomainGToP` and put it in a `Utils` file. --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 12 ++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 41 +++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 29 -------- lib/Dialect/Torch/Utils/Utils.cpp | 31 ++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 68 ++++++++++++++++++- 5 files changed, 150 insertions(+), 31 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 0e4c2b0a0..25d35f0f9 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -11,6 +11,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" namespace mlir { @@ -117,6 +118,17 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter, Value opSize, Value opStride, Location loc); +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, + Value sizeList); + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0191fcf36..b9bb6a540 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -27,7 +28,47 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("HardSigmoid", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensorOperand; + float alpha, beta; + if (binder.tensorOperand(tensorOperand) || + binder.f32FloatAttr(alpha, "alpha", 0.2) || + binder.f32FloatAttr(beta, "beta", 0.5) || + binder.tensorResultType(resultType)) + return failure(); + + // HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta)) + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + + // Expression: alpha * x + beta + Value alpha_x_plus_beta = rewriter.create( + binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); + + // Expression: min(1, alpha * x + beta) + Value constantOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantOne); + Value minExpression = rewriter.create( + binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); + + // Expression: max(0, min(1, alpha * x + beta)) + Value constantZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantZero); + rewriter.replaceOpWithNewOp( + binder.op, resultType, zeroTensor, minExpression); + return success(); + }); patterns.onOp( "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value operand; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 8162d2bb6..d8b8639e0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -126,35 +126,6 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc, return sub; } -// Helper to create a tensor filled with the given scalar. Scalar would be -// converted the to the element type of the given tensor type. -static Value createInitTensor(PatternRewriter &rewriter, Location loc, - BaseTensorType resultType, Value scalar, - Value sizeList) { - assert(resultType.hasDtype() && "result must have dtype"); - Value noneVal = rewriter.create(loc); - Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); - return rewriter.create(loc, resultType, sizeList, scalar, dtype, - /*layout=*/noneVal, - /*device=*/noneVal, - /*memory_format=*/noneVal); -} - -// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` -// would be converted to the element type of the given `inputType`. -static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, - BaseTensorType inputType, Value scalar) { - assert(inputType.hasDtype() && "input must have dtype"); - SmallVector sizes; - BaseTensorType rank0TensorTy = - inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) - .cast(); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), - ValueRange{}); - return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); -} - // Share code between `softmax_backward` and `log_softmax_backward` ops. // Returns x - y * sum(z, dim). static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index c5b0eec50..4bf5f7e13 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::torch; @@ -74,7 +75,6 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { } llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } - Type Torch::getTypeForTorchType( MLIRContext *context, Type type, mlir::IntegerType::SignednessSemantics signedness) { @@ -471,3 +471,32 @@ LogicalResult Torch::checkDefaultStrideHelper(Operation *op, return success(); } } + +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, + Value sizeList) { + assert(resultType.hasDtype() && "result must have dtype"); + Value noneVal = rewriter.create(loc); + Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); + return rewriter.create(loc, resultType, sizeList, scalar, dtype, + /*layout=*/noneVal, + /*device=*/noneVal, + /*memory_format=*/noneVal); +} + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar) { + assert(inputType.hasDtype() && "input must have dtype"); + SmallVector sizes; + BaseTensorType rank0TensorTy = + inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) + .cast(); + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), + ValueRange{}); + return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 27e7f2c6a..08bb69f23 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -198,4 +198,70 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> return %0 : !torch.vtensor<[3,4,5],i1> -} \ No newline at end of file +} + +// CHECK-LABEL: @test_hardsigmoid_example +func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32> + + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_hardsigmoid +func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_hardsigmoid_default +func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +}