diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 9d06bed30..3332aafc1 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -37,6 +37,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
 Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                   float val);
 
+// Templated function to create a constant op for given type and shape.
+// T: storage C type.
+// Default template creates a constant tensor in T.
+// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
+template <typename T>
+llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
+                                     ArrayRef<T> vec, ArrayRef<int64_t> shape);
+
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
 template
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index e26f697c6..c7569b5e3 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -63,5 +63,83 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
+// Templated function to create a constant op for given type and shape.
+// T: storage C type.
+// Default template creates a constant tensor in T.
+template <typename T>
+llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
+                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
+  uint64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    op->emitOpError("getConstTensor(): number of elements mismatch.");
+    return llvm::None;
+  }
+
+  auto const_type =
+      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+
+  auto const_op =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Template specialization for APInt
+template <>
+llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
+                                            Operation *op, ArrayRef<APInt> vec,
+                                            ArrayRef<int64_t> shape) {
+  uint64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    op->emitOpError("getConstTensor(): number of elements mismatch.");
+    return llvm::None;
+  }
+
+  auto const_type = RankedTensorType::get(
+      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+
+  auto const_op =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Template specialization for float
+template <>
+llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
+                                            Operation *op, ArrayRef<float> vec,
+                                            ArrayRef<int64_t> shape) {
+  uint64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    op->emitOpError("getConstTensor(): number of elements mismatch.");
+    return llvm::None;
+  }
+
+  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+
+  auto const_op =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Template instantiation
+template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
+                                                       Operation *,
+                                                       ArrayRef<int32_t> vec,
+                                                       ArrayRef<int64_t> shape);
+
 } // namespace tosa
 } // namespace mlir