//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"          // from @llvm-project
#include "mlir/IR/Matchers.h"              // from @llvm-project
#include "mlir/IR/PatternMatch.h"          // from @llvm-project
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
namespace tosa {

// Common function for lowering reduce operations to TOSA ops.
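// The lowering proceeds in three steps: if the axes list is empty, the input
// is passed through a tosa.identity; otherwise one TOSA reduce op is emitted
// per requested axis, with each reduced axis kept as size 1 so later axis
// indices remain valid; finally, a tosa.reshape drops the size-1 axes when
// keep_dims is false. For example, reducing a [2, 3, 4] input over axes
// {1, 2} yields reduce(axis=1) -> [2, 1, 4], then reduce(axis=2) -> [2, 1, 1],
// then a reshape to [2] when keep_dims is false.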
template <typename T>
llvm::Optional<Value> convertReduceOpCommon(
    PatternRewriter &rewriter, Operation *op, RankedTensorType output_type,
    Value input_value, ElementsAttr axes_elems, bool keep_dims,
    Type reduce_element_type, bool is_quantized, double input_scale,
    int64_t input_zp, double output_scale, int64_t output_zp) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  ArrayRef<int64_t> input_shape = input_type.getShape();
  ArrayRef<int64_t> output_shape = output_type.getShape();
  auto input_rank = input_shape.size();
  Value val = input_value;

  if (axes_elems.getNumElements() == 0) {
    // No axes means return the original tensor.
    auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
        rewriter, op->getLoc(), output_type, val);
    val = identity_op.getResult();
  } else {
    // Reduce along each axis
    SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());

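    // For quantized reductions, rescale the input into an int32 accumulator
    // first (folding in input_scale and input_zp); the matching rescale after
    // the loop converts the result back to the quantized output type.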
    if (is_quantized) {
      val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
    }

    for (int i = 0; i < axes_elems.getNumElements(); i++) {
      int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
      if (axis_val < 0)
        axis_val += input_rank;
      auto axis_attr = rewriter.getI64IntegerAttr(axis_val);

      shape_vec[axis_val] = 1;
      RankedTensorType reduce_type =
          RankedTensorType::get(shape_vec, reduce_element_type);

      auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
                                           val, axis_attr);

      val = reduce_op.getResult();
    }

    if (is_quantized) {
      RankedTensorType output_rescale_type =
          RankedTensorType::get(shape_vec, output_type.getElementType());
      val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
                         0, output_zp, false, true);
    }

    // Optionally squeeze out the reduced axes.
    if (!keep_dims) {
      auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
          rewriter, op->getLoc(), output_type, val,
          rewriter.getI64ArrayAttr(output_shape));
      val = reshape_op.getResult();
    }
  }

  return val;
}

// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

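  // is_quantized is false here, so convertReduceOpCommon skips both rescales
  // and the trailing scale/zero-point arguments are unused.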
  return convertReduceOpCommon<tosa::ReduceAllOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceAnyOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMinOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMaxOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype || output_is_qtype) {
    op->emitOpError("ConvertReduceProdOp: input/output tensor should "
                    "be all floating-point.");
    return llvm::None;
  }

  return convertReduceOpCommon<tosa::ReduceProdOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceSumOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  double input_scale = 1.0f;
  double output_scale = 1.0f;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

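    // Fold a 2^20 gain into the input rescale and divide it back out in the
    // output rescale. The sum is accumulated in int32, so this appears
    // intended to preserve extra fractional precision through the integer
    // accumulation (inferred from the scale arithmetic below, not a
    // documented contract).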
    int32_t input_shift = 20;

    input_scale =
        static_cast<double>(1 << input_shift) * input_qtype.getScale();
    output_scale =
        1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  return convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);
}

// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  // reduce_mean is lowered as follows:
  // op1 = reduce_sum(input)
  // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
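  // For example, reducing a [2, 3, 4] input over axes {1, 2} sums 3 * 4 = 12
  // elements per output value, so the sum is multiplied by 1/12. In the
  // quantized path the 1/N factor is folded into the output rescale instead
  // of emitting a separate multiply.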

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceMeanOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  // Only supports float type mean() if it's non-quantized
  if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
    op->emitWarning(
        "Failed convertReduceMean: input unquantized type but output element "
        "not FloatType!");
    return llvm::None;
  }

  int64_t input_rank = input_type.getRank();
  int64_t num_elems_on_reduced_axis = 1;
  for (int i = 0; i < axes_elems.getNumElements(); i++) {
    int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
    if (axis_val < 0)
      axis_val += input_rank;
    num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
  }
  double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);

  double input_scale = 1.0f;
  double output_scale = 1.0f;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // Combine 'div_scale' as part of output rescale
    output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);

  if (!val.hasValue())
    return llvm::None;

  if (!input_is_qtype) {
    Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
    return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
                                         val.getValue(), div_const, 0)
        .getResult();
  }

  return val;
}

} // namespace tosa
} // namespace mlir