mirror of https://github.com/llvm/torch-mlir
[tosa] Add more common utility functions (#525)
- Common code as TF repository, being moved to MLIR core. - Will support further legalizations to be published. Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>pull/527/head snapshot-20220114.206
parent
5ded7d096f
commit
edf4a0e729
|
@ -37,6 +37,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
|
||||||
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||||
float val);
|
float val);
|
||||||
|
|
||||||
|
// Templated function to create a constant op for given type and shape.
|
||||||
|
// T: storage C type.
|
||||||
|
// Default template creates a constant tensor in T.
|
||||||
|
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
|
||||||
|
template <typename T>
|
||||||
|
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
ArrayRef<T> vec, ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
// Creates a TOSA operation and performs shape inference on the individual
|
// Creates a TOSA operation and performs shape inference on the individual
|
||||||
// op. This allows shape inference during the framework to TOSA lowering.
|
// op. This allows shape inference during the framework to TOSA lowering.
|
||||||
template <typename TosaOp, typename... Args>
|
template <typename TosaOp, typename... Args>
|
||||||
|
|
|
@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||||
return const_op.getResult();
|
return const_op.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Templated function to create a constant op for given type and shape.
|
||||||
|
// T: storage C type.
|
||||||
|
// Default template creates a constant tensor in T.
|
||||||
|
template <typename T>
|
||||||
|
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
|
||||||
|
uint64_t num_total_elements = 1;
|
||||||
|
for (int64_t a : shape) {
|
||||||
|
num_total_elements *= a;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vec.size() != num_total_elements) {
|
||||||
|
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const_type =
|
||||||
|
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
||||||
|
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||||
|
|
||||||
|
auto const_op =
|
||||||
|
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||||
|
return const_op.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template specialization for APInt
|
||||||
|
template <>
|
||||||
|
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
||||||
|
Operation *op, ArrayRef<APInt> vec,
|
||||||
|
ArrayRef<int64_t> shape) {
|
||||||
|
uint64_t num_total_elements = 1;
|
||||||
|
for (int64_t a : shape) {
|
||||||
|
num_total_elements *= a;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vec.size() != num_total_elements) {
|
||||||
|
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const_type = RankedTensorType::get(
|
||||||
|
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
||||||
|
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||||
|
|
||||||
|
auto const_op =
|
||||||
|
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||||
|
return const_op.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template specialization for float
|
||||||
|
template <>
|
||||||
|
llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
||||||
|
Operation *op, ArrayRef<float> vec,
|
||||||
|
ArrayRef<int64_t> shape) {
|
||||||
|
uint64_t num_total_elements = 1;
|
||||||
|
for (int64_t a : shape) {
|
||||||
|
num_total_elements *= a;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vec.size() != num_total_elements) {
|
||||||
|
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
||||||
|
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||||
|
|
||||||
|
auto const_op =
|
||||||
|
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||||
|
return const_op.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template instantiation
|
||||||
|
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
||||||
|
Operation *,
|
||||||
|
ArrayRef<int32_t> vec,
|
||||||
|
ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
} // namespace tosa
|
} // namespace tosa
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue