diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1f21a1afe..ab4e284f8 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -516,13 +516,12 @@ public: if (!lhsTy) { return op.emitError("only Tensor types supported in StableHLO"); } + bool isRhsScalar = false; if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); - // use lhs's element type as compute type - rhs = - hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType()); rhsTy = dyn_cast(rhs.getType()); + isRhsScalar = true; } auto outType = cast( @@ -537,16 +536,28 @@ public: } if (isa(lhsElemTy) && isa(rhsElemTy)) { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); + // torch.lt(x_int, 1.1) use fp32 as compute type + // torch.lt(x_int, y_float) use y's float type as compute type + Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } else if (isa(lhsElemTy) && isa(rhsElemTy)) { + // always use lhs's float type as compute type rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - if (lhsElemTy.getIntOrFloatBitWidth() > - rhsElemTy.getIntOrFloatBitWidth()) { + if (isRhsScalar) { + // torch.lt(x_float, 1.1) use x's float type as compute type + // torch.lt(x_int, 1) use x's int type as compute type rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); + // torch.lt(x_float, y_float) use higher bitwidth as compute type + Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() > + rhsElemTy.getIntOrFloatBitWidth() + ? lhsElemTy + : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } } lhsElemTy = dyn_cast(lhs.getType()).getElementType(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 918cbae63..c99ef4d96 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -528,7 +528,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "AtenPolarFloatModule_basic", "DiagonalWithStaticShapeModule_basic", "EinsumStaticDiagonalDimensionModule_basic", - "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",