From 0bb62e4347d239018797b0829b44cdbffa78a3a2 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 2 May 2024 21:30:24 +0530 Subject: [PATCH] Revert Onnx.Selu lowering to corresponding Aten op (#3275) --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 28 ++----------------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 +++-------- 2 files changed, 7 insertions(+), 37 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 5f9da3faa..3553d22c7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); - Torch::ValueTensorType inputType = - operand.getType().cast(); - Value vAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); @@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); - Value cstOne = rewriter.create( + Value vInputScale = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); - Value cstNone = rewriter.create(binder.getLoc()); - Value zeroTensor = rewriter.create( - binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone, - cstNone, cstNone); - Value exp = rewriter.create(binder.getLoc(), - resultType, operand); - Value expMulAlpha = rewriter.create( - binder.getLoc(), resultType, exp, vAlpha); - Value expMulAlphaSubAlpha = rewriter.create( - binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne); - Value neg = rewriter.create( - binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale); - Value pos = rewriter.create( - binder.getLoc(), resultType, operand, vScale); - Type compareType = inputType.getWithSizesAndDtype( - inputType.getOptionalSizes(), rewriter.getI1Type()); - Value xLessThanZero = rewriter.create( - binder.getLoc(), compareType, operand, zeroTensor); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, xLessThanZero, neg, pos); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); patterns.onOp("ReduceL1", 1, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 2748a640a..0c2b9180c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -582,18 +582,10 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { - // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00 - // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00 - // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00 - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> - // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 + // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> }