[onnx] Lower onnx.HardSigmoid to torch (#2682)

The expression for HardSigmoid in Onnx
(https://onnx.ai/onnx/operators/onnx__HardSigmoid.html): max(0, min(1,
alpha * x + beta))

is inherently different from HardSigmoid in Torch
(https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html)
which is: if x < -3 -> 0
elif x > 3 -> 1
else x/6 + 1/2

That being said, it was just better to compute out the entire expression
when translating the Onnx expression to Torch mlir, which is done in
this PR. Some of the logic is shared from the files in
`DecomposeComplexOps`. Therefore, refactored some shared logic between
`DecomposeComplexOps` and `DefaultDomainGToP` and put it in a `Utils`
file.
pull/2706/head
John Wu 2023-12-21 07:29:22 -08:00 committed by GitHub
parent 779a141f8d
commit 46f2cb50dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 31 deletions

View File

@ -11,6 +11,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
namespace mlir {
@ -117,6 +118,17 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
Value opSize, Value opStride,
Location loc);
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
Value createInitTensor(PatternRewriter &rewriter, Location loc,
BaseTensorType resultType, Value scalar,
Value sizeList);
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
@ -27,7 +28,47 @@ using namespace mlir::torch::onnx_c;
// thing here, so we simplify.
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
OnnxCustomOpConversionPattern &patterns) {
patterns.onOp("HardSigmoid", 6,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensorOperand;
float alpha, beta;
if (binder.tensorOperand(tensorOperand) ||
binder.f32FloatAttr(alpha, "alpha", 0.2) ||
binder.f32FloatAttr(beta, "beta", 0.5) ||
binder.tensorResultType(resultType))
return failure();
// HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta))
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(beta));
// Expression: alpha * x + beta
Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha);
// Expression: min(1, alpha * x + beta)
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
resultType, constantOne);
Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta);
// Expression: max(0, min(1, alpha * x + beta))
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
resultType, constantZero);
rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
binder.op, resultType, zeroTensor, minExpression);
return success();
});
patterns.onOp(
"Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value operand;

View File

@ -126,35 +126,6 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
return sub;
}
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
BaseTensorType resultType, Value scalar,
Value sizeList) {
assert(resultType.hasDtype() && "result must have dtype");
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
/*layout=*/noneVal,
/*device=*/noneVal,
/*memory_format=*/noneVal);
}
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
assert(inputType.hasDtype() && "input must have dtype");
SmallVector<int64_t> sizes;
BaseTensorType rank0TensorTy =
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
.cast<BaseTensorType>();
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{});
return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
}
// Share code between `softmax_backward` and `log_softmax_backward` ops.
// Returns x - y * sum(z, dim).
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,

View File

@ -10,6 +10,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::torch;
@ -74,7 +75,6 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
}
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
Type Torch::getTypeForTorchType(
MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness) {
@ -471,3 +471,32 @@ LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
return success();
}
}
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc,
BaseTensorType resultType, Value scalar,
Value sizeList) {
assert(resultType.hasDtype() && "result must have dtype");
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
/*layout=*/noneVal,
/*device=*/noneVal,
/*memory_format=*/noneVal);
}
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
assert(inputType.hasDtype() && "input must have dtype");
SmallVector<int64_t> sizes;
BaseTensorType rank0TensorTy =
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
.cast<BaseTensorType>();
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{});
return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
}

View File

@ -198,4 +198,70 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
// CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
%0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
return %0 : !torch.vtensor<[3,4,5],i1>
}
}
// CHECK-LABEL: @test_hardsigmoid_example
func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32>
// CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
// CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[INT_0:.*]] = torch.constant.int 0
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
// CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// CHECK-LABEL: @test_hardsigmoid
func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
// CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT_0:.*]] = torch.constant.int 0
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
// CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// CHECK-LABEL: @test_hardsigmoid_default
func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224
// CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
// CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT_0:.*]] = torch.constant.int 0
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
// CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}