mirror of https://github.com/llvm/torch-mlir
[onnx] Fix onnx.ThresholdedRelu crash (#3638)
Result type was not fetched causing a crash on constructionpull/3644/head
parent
5b19ab93dc
commit
3a599bec80
|
@ -2623,7 +2623,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value input;
|
||||
float alpha;
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.f32FloatAttr(alpha, "alpha", 1.0)) {
|
||||
binder.f32FloatAttr(alpha, "alpha", 1.0) ||
|
||||
binder.tensorResultType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
||||
|
|
|
@ -3477,3 +3477,14 @@ func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[
|
|||
}
|
||||
return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_thresholdedrelu
|
||||
func.func @test_thresholdedrelu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} {
|
||||
// CHECK: %[[FP2:.+]] = torch.constant.float 2.000000e+00
|
||||
// CHECK: %[[FP0:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: torch.aten.threshold %arg0, %[[FP2]], %[[FP0]]
|
||||
%0 = torch.operator "onnx.ThresholdedRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue