[stablehlo] Reduce unnecessary template specialization code (#3047)

pull/3093/head
penguin_wwy 2024-04-02 05:18:49 +08:00 committed by GitHub
parent 826786bdd0
commit b98f7f75dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 26 additions and 69 deletions

View File

@ -61,76 +61,18 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
}
auto const_type =
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Template specialization for APInt
template <>
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op, ArrayRef<APInt> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto const_type = RankedTensorType::get(
RankedTensorType const_type;
if constexpr (std::is_same_v<T, APInt>) {
const_type = RankedTensorType::get(
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
} else if constexpr (std::is_same_v<T, float>) {
const_type = RankedTensorType::get(shape, rewriter.getF32Type());
} else if constexpr (std::is_same_v<T, double>) {
const_type = RankedTensorType::get(shape, rewriter.getF64Type());
} else {
const_type =
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
}
// Template specialization for float
template <>
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op, ArrayRef<float> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
template <>
std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
Operation *op, ArrayRef<double> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
@ -139,6 +81,21 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
}
// Template instantiation
template std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op,
ArrayRef<APInt> vec,
ArrayRef<int64_t> shape);
template std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op,
ArrayRef<float> vec,
ArrayRef<int64_t> shape);
template std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
Operation *op,
ArrayRef<double> vec,
ArrayRef<int64_t> shape);
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,
ArrayRef<int32_t> vec,