Fix onnx.If lowering with scalar condition tensor (#3846)

Fixes
https://github.com/nod-ai/SHARK-ModelDev/issues/696#issuecomment-2442016530
pull/3847/head
jinchen 2024-10-31 20:34:50 -07:00 committed by GitHub
parent 25738b8c19
commit 032a636c35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -180,7 +180,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto conditionType = auto conditionType =
cast<Torch::ValueTensorType>(conditionTensor.getType()); cast<Torch::ValueTensorType>(conditionTensor.getType());
if (!conditionType || conditionType.getSizes().size() != 1) if (!conditionType || conditionType.getSizes().size() > 1)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per " binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html"); "https://onnx.ai/onnx/operators/onnx__If.html");