mirror of https://github.com/llvm/torch-mlir
[NFC] Standardize the std::is_same competime expression (#3321)
parent
2c22087cab
commit
e0a87e543e
|
@ -449,11 +449,13 @@ public:
|
|||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhs = lhs;
|
||||
} else if (!rhsType) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
} else {
|
||||
if (!rhsType) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
}
|
||||
}
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
|
@ -462,8 +464,8 @@ public:
|
|||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
||||
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
||||
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
||||
if constexpr (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
||||
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
@ -575,32 +577,32 @@ public:
|
|||
op->getContext(), chlo::ComparisonType::SIGNED);
|
||||
}
|
||||
|
||||
if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||
if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::LT);
|
||||
} else if (std::is_same<AtenOpT, AtenGtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenGtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::GT);
|
||||
} else if (std::is_same<AtenOpT, AtenGeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenGeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::GE);
|
||||
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::EQ);
|
||||
} else if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::NE);
|
||||
} else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::LT);
|
||||
} else if (std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLeScalarOp>()) {
|
||||
} else if constexpr (std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLeScalarOp>()) {
|
||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), chlo::ComparisonDirection::LE);
|
||||
} else {
|
||||
|
|
|
@ -371,8 +371,8 @@ public:
|
|||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
// There is no Lesser operator in TOSA.
|
||||
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
|
||||
// Promote lhs and rhs dtypes for bitwise operators.
|
||||
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
|
@ -388,12 +388,15 @@ public:
|
|||
(swapLhsRhs ? lhs : rhsTensor));
|
||||
|
||||
// There is no NE operator in TOSA.
|
||||
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenNeScalarOp>())
|
||||
if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy,
|
||||
resultOp.getResult());
|
||||
else
|
||||
}
|
||||
|
||||
else {
|
||||
rewriter.replaceOp(op, resultOp.getResult());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -425,7 +428,7 @@ public:
|
|||
op, "Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value rhsTensor;
|
||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhsTensor = lhs;
|
||||
} else {
|
||||
Value rhsAsTensor;
|
||||
|
|
Loading…
Reference in New Issue