mirror of https://github.com/llvm/torch-mlir
[stablehlo] Reduce unnecessary template specialization code (#3047)
parent
826786bdd0
commit
b98f7f75dc
|
@ -61,76 +61,18 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto const_type =
|
||||
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// Template specialization for APInt
|
||||
template <>
|
||||
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<APInt> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
RankedTensorType const_type;
|
||||
if constexpr (std::is_same_v<T, APInt>) {
|
||||
const_type = RankedTensorType::get(
|
||||
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
const_type = RankedTensorType::get(shape, rewriter.getF64Type());
|
||||
} else {
|
||||
const_type =
|
||||
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
||||
}
|
||||
|
||||
if (vec.size() != num_total_elements) {
|
||||
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||
return std::nullopt;
|
||||
}
|
||||
auto const_type = RankedTensorType::get(
|
||||
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// Template specialization for float
|
||||
template <>
|
||||
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<float> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
}
|
||||
|
||||
if (vec.size() != num_total_elements) {
|
||||
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
template <>
|
||||
std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<double> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
}
|
||||
|
||||
if (vec.size() != num_total_elements) {
|
||||
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
|
@ -139,6 +81,21 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
// Template instantiation
|
||||
template std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
ArrayRef<APInt> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
|
||||
template std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
ArrayRef<float> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
|
||||
template std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
ArrayRef<double> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
|
||||
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
||||
Operation *,
|
||||
ArrayRef<int32_t> vec,
|
||||
|
|
Loading…
Reference in New Issue