From 3a599bec80c0f77d72984c88166bd558fad43f21 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Aug 2024 09:23:38 -0700 Subject: [PATCH] [onnx] Fix onnx.ThresholdedRelu crash (#3638) Result type was not fetched causing a crash on construction --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 3 ++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index dcb6e6763..68868e95c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2623,7 +2623,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value input; float alpha; if (binder.tensorOperand(input) || - binder.f32FloatAttr(alpha, "alpha", 1.0)) { + binder.f32FloatAttr(alpha, "alpha", 1.0) || + binder.tensorResultType(resultType)) { return failure(); } Value cstAlpha = rewriter.create( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 3c37cc9c5..be14dccd4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -3477,3 +3477,14 @@ func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[ } return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32> } + +// ----- + +// CHECK-LABEL: @test_thresholdedrelu +func.func @test_thresholdedrelu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} { + // CHECK: %[[FP2:.+]] = torch.constant.float 2.000000e+00 + // CHECK: %[[FP0:.+]] = torch.constant.float 0.000000e+00 + // CHECK: torch.aten.threshold %arg0, %[[FP2]], %[[FP0]] + %0 = torch.operator "onnx.ThresholdedRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +}