mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix aten compare ops' promote rules (#3709)
Previous PR: https://github.com/llvm/torch-mlir/pull/3702
parent d61986cfcf
commit 7b94ced39a
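The bug being fixed: for aten compare ops (e.g. torch.lt) with a scalar right-hand side, the StableHLO lowering promoted the scalar to the LHS element type, so torch.lt(x_int, 1.1) was effectively compared against 1. The hunks below choose the compute type so the result matches eager PyTorch. As a minimal reference (plain PyTorch, not part of this patch; the compute-type notes restate the comments added in the patch):

    import torch

    x_int = torch.tensor([1, 2, 3], dtype=torch.int32)
    y_float = torch.tensor([1.1, 2.0, 2.5], dtype=torch.float64)

    print(torch.lt(x_int, 1.1))      # int tensor vs. float scalar: compute as f32
    print(torch.lt(x_int, y_float))  # int tensor vs. float tensor: compute as y's f64
    print(torch.lt(x_int, 1))        # int tensor vs. int scalar:   compute as x's i32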
@@ -516,13 +516,12 @@ public:
     if (!lhsTy) {
       return op.emitError("only Tensor types supported in StableHLO");
     }
+    bool isRhsScalar = false;
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
-      // use lhs's element type as compute type
-      rhs =
-          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
+      isRhsScalar = true;
     }

     auto outType = cast<RankedTensorType>(
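The hunk above adds an isRhsScalar flag because once scalarToStablehloTensor has materialized the Python scalar as a tensor, nothing in the value itself distinguishes it from an ordinary tensor operand, yet the promotion rules in the next hunk need exactly that distinction. A rough analogy in plain PyTorch (names and values invented for this note):

    import torch

    rhs_scalar = 1.1
    rhs_as_tensor = torch.tensor(1.1)  # roughly what materializing the scalar yields

    # After materialization, dtype/shape no longer reveal the scalar origin,
    # which is why the lowering records it up front in isRhsScalar.
    print(type(rhs_scalar).__name__, rhs_as_tensor.dtype, rhs_as_tensor.shape)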
@@ -537,16 +536,28 @@ public:
     }

     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+      // torch.lt(x_int, 1.1) use fp32 as compute type
+      // torch.lt(x_int, y_float) use y's float type as compute type
+      Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy;
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
+      // always use lhs's float type as compute type
       rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
-      if (lhsElemTy.getIntOrFloatBitWidth() >
-          rhsElemTy.getIntOrFloatBitWidth()) {
+      if (isRhsScalar) {
+        // torch.lt(x_float, 1.1) use x's float type as compute type
+        // torch.lt(x_int, 1) use x's int type as compute type
         rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+        // torch.lt(x_float, y_float) use higher bitwidth as compute type
+        Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() >
+                                 rhsElemTy.getIntOrFloatBitWidth()
+                             ? lhsElemTy
+                             : rhsElemTy;
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
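Stated as a decision table, the new rules are: int tensor vs. float RHS uses f32 for a scalar and otherwise the RHS float type; float tensor vs. int RHS always uses the LHS float type; when both sides are the same kind, a scalar RHS follows the LHS type and two tensors use the wider element type. A hypothetical Python sketch of the same table (invented names; the real code operates on MLIR types rather than strings):

    def pick_compute_type(lhs_ty, rhs_ty, rhs_is_scalar, width):
        # lhs_ty / rhs_ty are tags such as "int32" or "float64";
        # width maps a tag to its bit width.
        lhs_is_float = lhs_ty.startswith("float")
        rhs_is_float = rhs_ty.startswith("float")
        if not lhs_is_float and rhs_is_float:
            # torch.lt(x_int, 1.1) -> f32; torch.lt(x_int, y_float) -> y's type
            return "float32" if rhs_is_scalar else rhs_ty
        if lhs_is_float and not rhs_is_float:
            # torch.lt(x_float, y_int) -> always x's float type
            return lhs_ty
        if rhs_is_scalar:
            # torch.lt(x_float, 1.1) / torch.lt(x_int, 1) -> x's type
            return lhs_ty
        # two tensors of the same kind -> the wider element type
        return lhs_ty if width[lhs_ty] >= width[rhs_ty] else rhs_ty

    widths = {"int32": 32, "int64": 64, "float32": 32, "float64": 64}
    assert pick_compute_type("int32", "float32", True, widths) == "float32"
    assert pick_compute_type("float64", "int32", False, widths) == "float64"
    assert pick_compute_type("int32", "int64", False, widths) == "int64"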
@@ -528,7 +528,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AtenPolarFloatModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
-    "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
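The test dropped from FX_IMPORTER_STABLEHLO_XFAIL_SET exercises the int-tensor vs. float-scalar comparison, which now agrees with eager PyTorch. A plain-PyTorch reference check for that pattern (independent of the torch-mlir e2e harness; values invented for illustration):

    import torch

    x = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32)

    # Eager PyTorch compares against 2.1; before this patch the StableHLO
    # lowering effectively compared against int(2.1) == 2, flipping the
    # result for elements equal to 2.
    print(torch.lt(x, 2.1))  # tensor([[ True,  True], [False, False]])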