[onnx] Fix onnx.ThresholdedRelu crash (#3638)

Result type was not fetched causing a crash on construction
pull/3644/head
Rob Suderman 2024-08-16 09:23:38 -07:00 committed by GitHub
parent 5b19ab93dc
commit 3a599bec80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 1 deletions

View File

@ -2623,7 +2623,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value input;
float alpha;
if (binder.tensorOperand(input) ||
binder.f32FloatAttr(alpha, "alpha", 1.0)) {
binder.f32FloatAttr(alpha, "alpha", 1.0) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(

View File

@ -3477,3 +3477,14 @@ func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[
}
return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: @test_thresholdedrelu
func.func @test_thresholdedrelu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} {
// CHECK: %[[FP2:.+]] = torch.constant.float 2.000000e+00
// CHECK: %[[FP0:.+]] = torch.constant.float 0.000000e+00
// CHECK: torch.aten.threshold %arg0, %[[FP2]], %[[FP0]]
%0 = torch.operator "onnx.ThresholdedRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}