mirror of https://github.com/llvm/torch-mlir
[onnx] Fix default `alpha` for `onnx.Elu` (#3583)
We were defaulting to `0.0` for `onnx.Elu` when it is supposed to be `1.0`.pull/3587/head
parent
3d33c5a206
commit
d273bdfabf
|
@ -2340,7 +2340,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
Value input;
|
||||
float alpha;
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.f32FloatAttr(alpha, "alpha") ||
|
||||
binder.f32FloatAttr(alpha, "alpha", 1.0) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
||||
|
|
|
@ -1579,7 +1579,7 @@ func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>,
|
|||
|
||||
// CHECK-LABEL: @test_elu_default
|
||||
func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: torch.aten.elu %arg0, %float1.000000e00, %float1.000000e00_0, %float1.000000e00_0 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue