From b3dfea223d2d991c9e76f0c85efbca90a2ee9092 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Wed, 7 Feb 2024 03:18:28 +0530 Subject: [PATCH] [LINALG][MLIR] Fix the broadcast dim check for elementwise ops lowering Signed-Off-by: Gaurav Shukla --- lib/Conversion/TorchToLinalg/Utils.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 20b32cd1f..b4d5a4564 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -297,7 +297,11 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( auto equalToRunning = b.create(loc, arith::CmpIPredicate::eq, resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, + auto equalToOne = b.create(loc, arith::CmpIPredicate::eq, + c1, currentDimSize); + auto isValidBroadcast = + b.create(loc, equalToRunning, equalToOne); + b.create(loc, isValidBroadcast, "mismatched size for broadcast"); } }