mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix aten compare ops' promote rules (#3709)
previous PR(https://github.com/llvm/torch-mlir/pull/3702)pull/3713/head
parent
d61986cfcf
commit
7b94ced39a
|
@ -516,13 +516,12 @@ public:
|
||||||
if (!lhsTy) {
|
if (!lhsTy) {
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
}
|
}
|
||||||
|
bool isRhsScalar = false;
|
||||||
if (!rhsTy) {
|
if (!rhsTy) {
|
||||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||||
rhs.getType());
|
rhs.getType());
|
||||||
// use lhs's element type as compute type
|
|
||||||
rhs =
|
|
||||||
hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
|
|
||||||
rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||||
|
isRhsScalar = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outType = cast<RankedTensorType>(
|
auto outType = cast<RankedTensorType>(
|
||||||
|
@ -537,16 +536,28 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
|
if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
|
||||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
|
// torch.lt(x_int, 1.1) use fp32 as compute type
|
||||||
|
// torch.lt(x_int, y_float) use y's float type as compute type
|
||||||
|
Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy;
|
||||||
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
|
||||||
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
|
||||||
} else if (isa<mlir::FloatType>(lhsElemTy) &&
|
} else if (isa<mlir::FloatType>(lhsElemTy) &&
|
||||||
isa<mlir::IntegerType>(rhsElemTy)) {
|
isa<mlir::IntegerType>(rhsElemTy)) {
|
||||||
|
// always use lhs's float type as compute type
|
||||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
||||||
} else {
|
} else {
|
||||||
if (lhsElemTy.getIntOrFloatBitWidth() >
|
if (isRhsScalar) {
|
||||||
rhsElemTy.getIntOrFloatBitWidth()) {
|
// torch.lt(x_float, 1.1) use x's float type as compute type
|
||||||
|
// torch.lt(x_int, 1) use x's int type as compute type
|
||||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
||||||
} else {
|
} else {
|
||||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
|
// torch.lt(x_float, y_float) use higher bitwidth as compute type
|
||||||
|
Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() >
|
||||||
|
rhsElemTy.getIntOrFloatBitWidth()
|
||||||
|
? lhsElemTy
|
||||||
|
: rhsElemTy;
|
||||||
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
|
||||||
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
|
lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
|
||||||
|
|
|
@ -528,7 +528,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"AtenPolarFloatModule_basic",
|
"AtenPolarFloatModule_basic",
|
||||||
"DiagonalWithStaticShapeModule_basic",
|
"DiagonalWithStaticShapeModule_basic",
|
||||||
"EinsumStaticDiagonalDimensionModule_basic",
|
"EinsumStaticDiagonalDimensionModule_basic",
|
||||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
|
||||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
||||||
|
|
Loading…
Reference in New Issue