[tosa] Add more common utility functions (#525)

- Common code as TF repository, being moved to MLIR core.
- Will support further legalizations to be published.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/527/head snapshot-20220114.206
Suraj Sudhir 2022-01-14 13:57:27 -08:00 committed by GitHub
parent 5ded7d096f
commit edf4a0e729
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 0 deletions

View File

@ -37,6 +37,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape);
// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>

View File

@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None;
}
auto const_type =
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op, ArrayRef<APInt> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None;
}
auto const_type = RankedTensorType::get(
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op, ArrayRef<float> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None;
}
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,
ArrayRef<int32_t> vec,
ArrayRef<int64_t> shape);
} // namespace tosa
} // namespace mlir