[LINALG][MLIR] Fix the broadcast dim check for elementwise ops lowering

Signed-Off-by: Gaurav Shukla <gauravshukla789@gmail.com>
pull/2882/head
Gaurav Shukla 2024-02-07 03:18:28 +05:30
parent cc06391630
commit b3dfea223d
1 changed files with 5 additions and 1 deletions

View File

@ -297,7 +297,11 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
auto equalToOne = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
c1, currentDimSize);
auto isValidBroadcast =
b.create<arith::OrIOp>(loc, equalToRunning, equalToOne);
b.create<cf::AssertOp>(loc, isValidBroadcast,
"mismatched size for broadcast");
}
}