mirror of https://github.com/llvm/torch-mlir
[NFC] Standardize the std::is_same competime expression (#3321)
parent
2c22087cab
commit
e0a87e543e
|
@ -449,11 +449,13 @@ public:
|
||||||
"only floating-point or integer datatype legalization supported");
|
"only floating-point or integer datatype legalization supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||||
rhs = lhs;
|
rhs = lhs;
|
||||||
} else if (!rhsType) {
|
} else {
|
||||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
if (!rhsType) {
|
||||||
outElemTy);
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||||
|
outElemTy);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
DenseI64ArrayAttr bcastDimensions;
|
DenseI64ArrayAttr bcastDimensions;
|
||||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||||
|
@ -462,8 +464,8 @@ public:
|
||||||
Value result =
|
Value result =
|
||||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||||
|
|
||||||
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
if constexpr (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
||||||
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOp(op, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -575,32 +577,32 @@ public:
|
||||||
op->getContext(), chlo::ComparisonType::SIGNED);
|
op->getContext(), chlo::ComparisonType::SIGNED);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::LT);
|
op->getContext(), chlo::ComparisonDirection::LT);
|
||||||
} else if (std::is_same<AtenOpT, AtenGtTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenGtTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::GT);
|
op->getContext(), chlo::ComparisonDirection::GT);
|
||||||
} else if (std::is_same<AtenOpT, AtenGeTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenGeTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::GE);
|
op->getContext(), chlo::ComparisonDirection::GE);
|
||||||
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::EQ);
|
op->getContext(), chlo::ComparisonDirection::EQ);
|
||||||
} else if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::NE);
|
op->getContext(), chlo::ComparisonDirection::NE);
|
||||||
} else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::LT);
|
op->getContext(), chlo::ComparisonDirection::LT);
|
||||||
} else if (std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
} else if constexpr (std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenLeScalarOp>()) {
|
std::is_same<AtenOpT, AtenLeScalarOp>()) {
|
||||||
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
||||||
op->getContext(), chlo::ComparisonDirection::LE);
|
op->getContext(), chlo::ComparisonDirection::LE);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -371,8 +371,8 @@ public:
|
||||||
}
|
}
|
||||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||||
// There is no Lesser operator in TOSA.
|
// There is no Lesser operator in TOSA.
|
||||||
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||||
|
|
||||||
// Promote lhs and rhs dtypes for bitwise operators.
|
// Promote lhs and rhs dtypes for bitwise operators.
|
||||||
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
|
@ -388,12 +388,15 @@ public:
|
||||||
(swapLhsRhs ? lhs : rhsTensor));
|
(swapLhsRhs ? lhs : rhsTensor));
|
||||||
|
|
||||||
// There is no NE operator in TOSA.
|
// There is no NE operator in TOSA.
|
||||||
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenNeScalarOp>())
|
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
||||||
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy,
|
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultTy,
|
||||||
resultOp.getResult());
|
resultOp.getResult());
|
||||||
else
|
}
|
||||||
|
|
||||||
|
else {
|
||||||
rewriter.replaceOp(op, resultOp.getResult());
|
rewriter.replaceOp(op, resultOp.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -425,7 +428,7 @@ public:
|
||||||
op, "Only floating-point or integer datatype legalization supported");
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
|
||||||
Value rhsTensor;
|
Value rhsTensor;
|
||||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||||
rhsTensor = lhs;
|
rhsTensor = lhs;
|
||||||
} else {
|
} else {
|
||||||
Value rhsAsTensor;
|
Value rhsAsTensor;
|
||||||
|
|
Loading…
Reference in New Issue