[NFC] Standardize the std::is_same competime expression (#3321)

pull/3323/head
penguin_wwy 2024-05-10 17:07:37 +08:00 committed by GitHub
parent 2c22087cab
commit e0a87e543e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 26 deletions

View File

@ -449,11 +449,13 @@ public:
"only floating-point or integer datatype legalization supported");
}
if (std::is_same<AtenOpT, AtenSquareOp>()) {
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
} else {
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
}
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
@ -462,8 +464,8 @@ public:
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
if constexpr (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
rewriter.replaceOp(op, result);
return success();
}
@ -575,32 +577,32 @@ public:
op->getContext(), chlo::ComparisonType::SIGNED);
}
if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LT);
} else if (std::is_same<AtenOpT, AtenGtTensorOp>() ||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenGtTensorOp>() ||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::GT);
} else if (std::is_same<AtenOpT, AtenGeTensorOp>() ||
std::is_same<AtenOpT, AtenGeScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenGeTensorOp>() ||
std::is_same<AtenOpT, AtenGeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::GE);
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenEqTensorOp>() ||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::EQ);
} else if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::NE);
} else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LT);
} else if (std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>()) {
} else if constexpr (std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LE);
} else {

View File

@ -371,8 +371,8 @@ public:
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// There is no Lesser operator in TOSA.
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>());
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>());
// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
@ -388,12 +388,15 @@ public:
(swapLhsRhs ? lhs : rhsTensor));
// There is no NE operator in TOSA.
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>())
if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy,
resultOp.getResult());
else
}
else {
rewriter.replaceOp(op, resultOp.getResult());
}
return success();
}
@ -425,7 +428,7 @@ public:
op, "Only floating-point or integer datatype legalization supported");
Value rhsTensor;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
rhsTensor = lhs;
} else {
Value rhsAsTensor;