[Stablehlo] fix aten compare ops' promote rules (#3709)

Previous PR: https://github.com/llvm/torch-mlir/pull/3702
Yuanqiang Liu 2024-09-13 18:48:41 +08:00 committed by GitHub
parent d61986cfcf
commit 7b94ced39a
2 changed files with 18 additions and 8 deletions


@@ -516,13 +516,12 @@ public:
     if (!lhsTy) {
       return op.emitError("only Tensor types supported in StableHLO");
     }
+    bool isRhsScalar = false;
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
-      // use lhs's element type as compute type
-      rhs =
-          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
+      isRhsScalar = true;
     }
     auto outType = cast<RankedTensorType>(
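The hunk above stops promoting a scalar rhs to lhs's element type up front and instead records isRhsScalar so the compute type can be chosen later. A minimal PyTorch sketch (values chosen only for illustration) of why the old rule loses information for torch.lt(x_int, 1.1):

    import torch

    x = torch.tensor([1], dtype=torch.int32)
    # PyTorch promotes both operands to fp32, so the comparison is 1.0 < 1.1.
    print(torch.lt(x, 1.1))                                # tensor([True])
    # Casting the scalar to lhs's int type first would compare 1 < 1 instead.
    print(torch.lt(x, torch.tensor(1.1).to(torch.int32)))  # tensor([False])
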
@@ -537,16 +536,28 @@ public:
     }
     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+      // torch.lt(x_int, 1.1) use fp32 as compute type
+      // torch.lt(x_int, y_float) use y's float type as compute type
+      Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy;
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
+      // always use lhs's float type as compute type
       rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
-      if (lhsElemTy.getIntOrFloatBitWidth() >
-          rhsElemTy.getIntOrFloatBitWidth()) {
+      if (isRhsScalar) {
+        // torch.lt(x_float, 1.1) use x's float type as compute type
+        // torch.lt(x_int, 1) use x's int type as compute type
         rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+        // torch.lt(x_float, y_float) use higher bitwidth as compute type
+        Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() >
+                                 rhsElemTy.getIntOrFloatBitWidth()
+                             ? lhsElemTy
+                             : rhsElemTy;
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
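The compute types chosen in the branches above are meant to mirror PyTorch's tensor/scalar promotion rules; a minimal sketch checking the cases named in the comments (tensor names and dtypes are illustrative only):

    import torch

    x_int = torch.tensor([1], dtype=torch.int32)
    y_f16 = torch.tensor([1.0], dtype=torch.float16)
    x_f64 = torch.tensor([1.0], dtype=torch.float64)

    assert torch.result_type(x_int, 1.1) == torch.float32    # int tensor vs. float scalar: fp32
    assert torch.result_type(x_int, y_f16) == torch.float16  # int tensor vs. float tensor: y's float type
    assert torch.result_type(x_f64, 1) == torch.float64      # float tensor vs. int scalar: x's float type
    assert torch.result_type(x_int, 1) == torch.int32        # int tensor vs. int scalar: x's int type
    assert torch.result_type(y_f16, x_f64) == torch.float64  # both float tensors: higher bitwidth
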


@@ -528,7 +528,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AtenPolarFloatModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
-    "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",