//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include #include #include #include #include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { // Common function for lowering reduce operations to TOSA ops. template llvm::Optional convertReduceOpCommon( PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims, Type reduce_element_type, bool is_quantized, double input_scale, int64_t input_zp, double output_scale, int64_t output_zp) { RankedTensorType input_type = input_value.getType().dyn_cast(); if (!input_type) return llvm::None; ArrayRef input_shape = input_type.getShape(); ArrayRef output_shape = output_type.getShape(); auto input_rank = input_shape.size(); Value val = input_value; if (axes_elems.getNumElements() == 0) { // No axes means return the original tensor. 
auto identity_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, val); val = identity_op.getResult(); } else { // Reduce along each axis SmallVector shape_vec(input_shape.begin(), input_shape.end()); if (is_quantized) { val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp); } for (int i = 0; i < axes_elems.getNumElements(); i++) { int64_t axis_val = axes_elems.getValues()[i].getInt(); if (axis_val < 0) axis_val += input_rank; auto axis_attr = rewriter.getI64IntegerAttr(axis_val); shape_vec[axis_val] = 1; RankedTensorType reduce_type = RankedTensorType::get(shape_vec, reduce_element_type); auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val, axis_attr); val = reduce_op.getResult(); } if (is_quantized) { RankedTensorType output_rescale_type = RankedTensorType::get(shape_vec, output_type.getElementType()); val = buildRescale(rewriter, op, output_rescale_type, val, output_scale, 0, output_zp, false, true); } // Optionally squeeze out the reduced axes. if (!keep_dims) { auto reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, val, rewriter.getI64ArrayAttr(output_shape)); val = reshape_op.getResult(); } } return val; } // Lowers ReduceAll to a sequence of TOSA ops. llvm::Optional convertReduceAllOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = input_value.getType().dyn_cast(); if (!input_type) return llvm::None; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, keep_dims, output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); } // Lowers ReduceAny to a sequence of TOSA ops. 
// Returns llvm::None when `input_value` is not a ranked tensor; otherwise
// forwards to convertReduceOpCommon with tosa::ReduceAnyOp, the output
// element type as the reduce type, and no quantization.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceAnyOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), /*is_quantized=*/false,
      /*input_scale=*/1.0, /*input_zp=*/0, /*output_scale=*/1.0,
      /*output_zp=*/0);
}

// Lowers ReduceMin to a sequence of TOSA ops.
//
// Returns llvm::None when `input_value` is not a ranked tensor; otherwise
// forwards to convertReduceOpCommon with tosa::ReduceMinOp, the output
// element type as the reduce type, and no quantization.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMinOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), /*is_quantized=*/false,
      /*input_scale=*/1.0, /*input_zp=*/0, /*output_scale=*/1.0,
      /*output_zp=*/0);
}

// Lowers ReduceMax to a sequence of TOSA ops.
//
// Returns llvm::None when `input_value` is not a ranked tensor; otherwise
// forwards to convertReduceOpCommon with tosa::ReduceMaxOp, the output
// element type as the reduce type, and no quantization.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMaxOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), /*is_quantized=*/false,
      /*input_scale=*/1.0, /*input_zp=*/0, /*output_scale=*/1.0,
      /*output_zp=*/0);
}

// Lowers ReduceProd to a sequence of TOSA ops.
// Quantized inputs or outputs are rejected with an op error. Returns
// llvm::None on that rejection or when `input_value` is not a ranked tensor;
// otherwise forwards to convertReduceOpCommon with tosa::ReduceProdOp.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype || output_is_qtype) {
    op->emitOpError("ConvertReduceProdOp: input/output tensor should "
                    "be all floating-point.");
    return llvm::None;
  }

  return convertReduceOpCommon<tosa::ReduceProdOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), /*is_quantized=*/false,
      /*input_scale=*/1.0, /*input_zp=*/0, /*output_scale=*/1.0,
      /*output_zp=*/0);
}

// Lowers ReduceSum to a sequence of TOSA ops.
//
// Input and output must both be quantized or both be non-quantized;
// otherwise an op error is emitted and llvm::None is returned. For quantized
// tensors the sum is accumulated in i32: the input is rescaled by
// 2^20 * input_scale and the result by 1 / (output_scale * 2^20).
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceSumOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  double input_scale = 1.0;
  double output_scale = 1.0;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // The same power-of-two factor is applied on the way in and removed on
    // the way out.
    int32_t input_shift = 20;

    input_scale =
        static_cast<double>(1 << input_shift) * input_qtype.getScale();
    output_scale =
        1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  return convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);
}

// Lowers ReduceMean to a sequence of TOSA ops.
//
// Returns llvm::None when the input is not a ranked tensor, when exactly one
// of input/output is quantized, when a non-quantized output is not a float
// type, or when a reduced axis is out of range or has a dynamic size.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  // reduce_mean is lowered as followed:
  // op1 = reduce_sum(input)
  // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceMeanOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  // Only supports float type mean() if it's non-quantized
  if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
    op->emitWarning(
        "Failed convertReduceMean: input unquantized type but output element "
        "not FloatType!");
    return llvm::None;
  }

  const int64_t input_rank = input_type.getRank();
  int64_t num_elems_on_reduced_axis = 1;
  for (IntegerAttr axis_elem : axes_elems.getValues<IntegerAttr>()) {
    int64_t axis_val = axis_elem.getInt();
    if (axis_val < 0)
      axis_val += input_rank;
    // Guard the shape lookup below against an out-of-range axis.
    if (axis_val < 0 || axis_val >= input_rank) {
      op->emitOpError("ConvertReduceMeanOp: reduction axis is out of range "
                      "for the input rank.");
      return llvm::None;
    }
    // A dynamic extent is not an element count, so the divisor computed from
    // it would be meaningless.
    if (input_type.isDynamicDim(axis_val)) {
      op->emitWarning("Failed convertReduceMean: reduced dimension has a "
                      "dynamic size!");
      return llvm::None;
    }
    num_elems_on_reduced_axis *= input_type.getDimSize(axis_val);
  }
  double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);

  double input_scale = 1.0;
  double output_scale = 1.0;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // Combine 'div_scale' as part of output rescale
    output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);

  if (!val.has_value())
    return llvm::None;

  // The quantized path already folded the division into the output rescale;
  // the float path needs an explicit multiply by 1/N.
  if (!input_is_qtype) {
    Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
    return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
                                         val.value(), div_const, 0)
        .getResult();
  }

  return val;
}

} // namespace tosa
} // namespace mlir