From dce6c83f96febb5dfca5fc21e8e623f58fd0fc6a Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Sun, 27 Aug 2023 21:53:37 -0400 Subject: [PATCH] Allow lower bounds check in broadcast simplification --- .../TMTensor/Transforms/ConvertBroadcastToLinalg.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertBroadcastToLinalg.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertBroadcastToLinalg.cpp index dc9e3e45c..5097867cb 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertBroadcastToLinalg.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertBroadcastToLinalg.cpp @@ -53,6 +53,14 @@ public: broadcastedStatus.push_back(false); continue; } + // If the dim is non-unit, then it is an assert + non-broadcasted dim. + FailureOr lb = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::LB, input, i); + if (succeeded(lb) && *lb > 1) { + // TODO: Insert appropriate assert here. + broadcastedStatus.push_back(false); + continue; + } FailureOr isUnit = ValueBoundsConstraintSet::areEqual(input, oneIndex, i, std::nullopt); if (succeeded(isUnit) && *isUnit) {