mirror of https://github.com/llvm/torch-mlir
[tosa] Add more common utility functions (#525)
- Common code, same as the TF repository, being moved to MLIR core.
- Will support further legalizations to be published.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
parent 5ded7d096f
commit edf4a0e729
@@ -37,6 +37,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val);

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
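For orientation, here is a minimal sketch of how a lowering pattern could call the getConstTensor helper declared above. Only the getConstTensor signature comes from this change; the wrapper function name, the zero values, and the shape are hypothetical.

// Illustrative sketch only, not part of this commit. Assumes the header that
// declares mlir::tosa::getConstTensor is already included.
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/PatternMatch.h"

static mlir::Value createZeroConstI32(mlir::PatternRewriter &rewriter,
                                      mlir::Operation *op, int64_t numElems) {
  llvm::SmallVector<int32_t> zeros(numElems, 0);
  llvm::Optional<mlir::Value> result =
      mlir::tosa::getConstTensor<int32_t>(rewriter, op, zeros,
                                          /*shape=*/{numElems});
  // getConstTensor() emits an op error and returns llvm::None when the number
  // of elements does not match the shape, so check before using the value.
  return result.hasValue() ? result.getValue() : mlir::Value();
}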
@@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
  return const_op.getResult();
}

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
                                            Operation *op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
                                            Operation *op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                       Operation *,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

} // namespace tosa
} // namespace mlir
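The header comment notes that an INT48 TOSA constant should go through the llvm::APInt overload, since the default template derives its element width from sizeof(T). A minimal sketch of that path follows; the wrapper name, element values, and shape are illustrative assumptions, and only the getConstTensor<APInt> specialization above is from this change.

// Illustrative sketch only, not part of this commit. Assumes the header that
// declares mlir::tosa::getConstTensor is already included.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/PatternMatch.h"

static llvm::Optional<mlir::Value>
createInt48Const(mlir::PatternRewriter &rewriter, mlir::Operation *op) {
  // The APInt specialization reads the bit width from vec[0], so every
  // element must already be built as a 48-bit APInt.
  llvm::SmallVector<llvm::APInt> vals;
  for (int64_t v : {1, 2, 3, 4})
    vals.push_back(llvm::APInt(/*numBits=*/48, v, /*isSigned=*/true));
  return mlir::tosa::getConstTensor<llvm::APInt>(rewriter, op, vals,
                                                 /*shape=*/{2, 2});
}

Because the specialization builds const_type from vec[0].getBitWidth(), DenseElementsAttr::get() receives 48-bit elements and the resulting tosa.const carries an i48 tensor type.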