Revert Onnx.Selu lowering to corresponding Aten op (#3275)

Vivek Khandelwal 2024-05-02 21:30:24 +05:30 committed by GitHub
parent 11cd7cd9e7
commit 0bb62e4347
2 changed files with 7 additions and 37 deletions

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

@@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.tensorResultType(resultType))
             return failure();
 
-        Torch::ValueTensorType inputType =
-            operand.getType().cast<Torch::ValueTensorType>();
-
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
 
-        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
 
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
-            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
-            cstNone, cstNone);
-        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, exp, vAlpha);
-        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
-            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
-        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
-        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, operand, vScale);
-        Type compareType = inputType.getWithSizesAndDtype(
-            inputType.getOptionalSizes(), rewriter.getI1Type());
-        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
-            binder.getLoc(), compareType, operand, zeroTensor);
-
-        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
-            binder.op, resultType, xLessThanZero, neg, pos);
+        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
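Reviewer's note on why the single-op lowering is sound: ONNX Selu computes gamma * (x if x > 0 else alpha * (exp(x) - 1)), while PyTorch's three-parameter elu computes scale * (x if x > 0 else alpha * (exp(input_scale * x) - 1)), so binding scale = gamma and input_scale = 1 (the vInputScale constant above) makes the two coincide. A minimal standalone C++ sketch of the scalar math, with illustrative names (onnxSelu, atenEluRef) that are not taken from the patch:

#include <cassert>
#include <cmath>

// ONNX Selu semantics: gamma * (x if x > 0 else alpha * (exp(x) - 1)).
double onnxSelu(double x, double alpha, double gamma) {
  return x > 0 ? gamma * x : gamma * alpha * (std::exp(x) - 1.0);
}

// Aten elu semantics (assumed):
// scale * (x if x > 0 else alpha * (exp(input_scale * x) - 1)).
double atenEluRef(double x, double alpha, double scale, double inputScale) {
  return x > 0 ? scale * x : scale * alpha * (std::exp(inputScale * x) - 1.0);
}

int main() {
  // With scale = gamma and input_scale = 1, elu reproduces Selu exactly;
  // alpha = 2, gamma = 3 mirrors the lit test below.
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0})
    assert(std::fabs(onnxSelu(x, 2.0, 3.0) - atenEluRef(x, 2.0, 3.0, 1.0)) <
           1e-12);
  return 0;
}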

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

@@ -582,18 +582,10 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
 
 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
-  // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00
-  // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00
-  // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
-  // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
+  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
+  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
+  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
   %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
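Note that the rewritten test matches the three float constants with CHECK-DAG rather than CHECK, since nothing pins the order in which the lowering materializes them; only the final torch.aten.elu line fixes the operand order (input, alpha %[[F2]], scale %[[F3]], input_scale %[[F1]]).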