From f1e3701cafe827e242cddc11124b9b222c716e3c Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 2 Jul 2024 15:31:06 +0800 Subject: [PATCH] [Stablehlo] fix compareOp with scalar's lowering (#3518) * use lhs tensor's element type as compute type when rhs is scalar. * previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>) -> tensor<2x5xi1>` * now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>` --- lib/Conversion/TorchToStablehlo/Basic.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 4d7597902..644d28cc0 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -517,6 +517,8 @@ public: if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); + // use lhs's element type as compute type + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); rhsTy = dyn_cast(rhs.getType()); }