mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix compareOp with scalar's lowering (#3518)
* use lhs tensor's element type as compute type when rhs is scalar. * previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>) -> tensor<2x5xi1>` * now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>`pull/3521/head
parent
e2fbded49c
commit
f1e3701caf
|
@ -517,6 +517,8 @@ public:
|
|||
if (!rhsTy) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
rhs.getType());
|
||||
// use lhs's element type as compute type
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
||||
rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue