[onnx] Fix default `alpha` for `onnx.Elu` (#3583)

We were defaulting to `0.0` for `onnx.Elu` when it is supposed to be
`1.0`.
pull/3587/head
Rob Suderman 2024-08-02 09:29:17 -07:00 committed by GitHub
parent 3d33c5a206
commit d273bdfabf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -2340,7 +2340,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value input;
float alpha;
if (binder.tensorOperand(input) ||
binder.f32FloatAttr(alpha, "alpha") ||
binder.f32FloatAttr(alpha, "alpha", 1.0) ||
binder.tensorResultType(resultType))
return failure();
Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(

View File

@ -1579,7 +1579,7 @@ func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>,
// CHECK-LABEL: @test_elu_default
func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.elu %arg0, %float1.000000e00, %float1.000000e00_0, %float1.000000e00_0 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}