[NFC] Standardize the std::is_same competime expression (#3321)

pull/3323/head
penguin_wwy 2024-05-10 17:07:37 +08:00 committed by GitHub
parent 2c22087cab
commit e0a87e543e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 26 deletions

View File

@ -449,11 +449,13 @@ public:
"only floating-point or integer datatype legalization supported"); "only floating-point or integer datatype legalization supported");
} }
if (std::is_same<AtenOpT, AtenSquareOp>()) { if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs; rhs = lhs;
} else if (!rhsType) { } else {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), if (!rhsType) {
outElemTy); rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
} }
DenseI64ArrayAttr bcastDimensions; DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
@ -462,8 +464,8 @@ public:
Value result = Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions); rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() && if constexpr (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) { !std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
return success(); return success();
} }
@ -575,32 +577,32 @@ public:
op->getContext(), chlo::ComparisonType::SIGNED); op->getContext(), chlo::ComparisonType::SIGNED);
} }
if (std::is_same<AtenOpT, AtenLtTensorOp>() || if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) { std::is_same<AtenOpT, AtenLtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LT); op->getContext(), chlo::ComparisonDirection::LT);
} else if (std::is_same<AtenOpT, AtenGtTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenGtTensorOp>() ||
std::is_same<AtenOpT, AtenGtScalarOp>()) { std::is_same<AtenOpT, AtenGtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::GT); op->getContext(), chlo::ComparisonDirection::GT);
} else if (std::is_same<AtenOpT, AtenGeTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenGeTensorOp>() ||
std::is_same<AtenOpT, AtenGeScalarOp>()) { std::is_same<AtenOpT, AtenGeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::GE); op->getContext(), chlo::ComparisonDirection::GE);
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenEqTensorOp>() ||
std::is_same<AtenOpT, AtenEqScalarOp>()) { std::is_same<AtenOpT, AtenEqScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::EQ); op->getContext(), chlo::ComparisonDirection::EQ);
} else if (std::is_same<AtenOpT, AtenNeTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>()) { std::is_same<AtenOpT, AtenNeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::NE); op->getContext(), chlo::ComparisonDirection::NE);
} else if (std::is_same<AtenOpT, AtenLtTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()) { std::is_same<AtenOpT, AtenLtScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LT); op->getContext(), chlo::ComparisonDirection::LT);
} else if (std::is_same<AtenOpT, AtenLeTensorOp>() || } else if constexpr (std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>()) { std::is_same<AtenOpT, AtenLeScalarOp>()) {
compareDirectionAttr = chlo::ComparisonDirectionAttr::get( compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
op->getContext(), chlo::ComparisonDirection::LE); op->getContext(), chlo::ComparisonDirection::LE);
} else { } else {

View File

@ -371,8 +371,8 @@ public:
} }
auto rhsTensor = rhsTy ? rhs : rhsAsTensor; auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// There is no Lesser operator in TOSA. // There is no Lesser operator in TOSA.
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() || constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>()); std::is_same<AtenOpT, AtenLtScalarOp>());
// Promote lhs and rhs dtypes for bitwise operators. // Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
@ -388,12 +388,15 @@ public:
(swapLhsRhs ? lhs : rhsTensor)); (swapLhsRhs ? lhs : rhsTensor));
// There is no NE operator in TOSA. // There is no NE operator in TOSA.
if (std::is_same<AtenOpT, AtenNeTensorOp>() || if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>()) std::is_same<AtenOpT, AtenNeScalarOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy, rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy,
resultOp.getResult()); resultOp.getResult());
else }
else {
rewriter.replaceOp(op, resultOp.getResult()); rewriter.replaceOp(op, resultOp.getResult());
}
return success(); return success();
} }
@ -425,7 +428,7 @@ public:
op, "Only floating-point or integer datatype legalization supported"); op, "Only floating-point or integer datatype legalization supported");
Value rhsTensor; Value rhsTensor;
if (std::is_same<AtenOpT, AtenSquareOp>()) { if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
rhsTensor = lhs; rhsTensor = lhs;
} else { } else {
Value rhsAsTensor; Value rhsAsTensor;