[tosa] Add Torch reduction operators

- Supports variants with multiple dims, one dim, and all dims
- Leverages legalize_common and legalize_utils code from
TensorFlow-TOSA work

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/461/head
Suraj Sudhir 2021-12-02 16:52:01 -08:00 committed by Sean Silva
parent ab6211184f
commit c9c9b68d1f
8 changed files with 854 additions and 8 deletions

@@ -0,0 +1,64 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace tosa {
// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);
} // namespace tosa
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
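
A minimal usage sketch for these helpers, assuming a conversion pattern that has already computed `outputTy`, `input`, and a constant `reduceDimsAttr` (all names illustrative; the real call site is the reduction template in TorchToTosa.cpp below):

llvm::Optional<Value> result = mlir::tosa::convertReduceSumOp(
    rewriter, op, outputTy, input, reduceDimsAttr, /*keep_dims=*/false);
if (!result)
  return rewriter.notifyMatchFailure(op, "reduce lowering failed");
rewriter.replaceOp(op, {result.getValue()});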

@@ -0,0 +1,96 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace tosa {
// Create a TOSA rescale op from input framework scaling, zero points and
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, bool double_round,
bool scale32);
// Creates a TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value input_val, double input_scale,
int64_t input_zp);
// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework-to-TOSA lowering.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
Args &&... args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getRegions(), returnedShapes)
.failed())
return op;
// We need to use the element type of the existing result type to generate
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
auto result = op->getResult(0);
auto predictedShape = returnedShapes[0];
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}
// Compute the new type based on the joined version.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType();
result.setType(new_ty);
return op;
}
template <typename TosaOp, typename... Args>
void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
Type result_ty, Args &&... args) {
auto result =
CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...);
rewriter.replaceOp(op, result->getResults());
}
} // namespace tosa
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
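
As a usage sketch (all names assumed and illustrative: `rewriter`, `op`, and an f32 `input` value), CreateOpAndInfer is called the way the reduction lowerings in TosaLegalizeCommon.cpp use it; shape inference may tighten the dims of the requested result type after creation:

RankedTensorType reduceTy =
    RankedTensorType::get({1, 3, 4}, rewriter.getF32Type());
auto reduceOp = CreateOpAndInfer<tosa::ReduceSumOp>(
    rewriter, op->getLoc(), reduceTy, input, rewriter.getI64IntegerAttr(0));
Value result = reduceOp.getResult();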

@@ -1,5 +1,7 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchToTosa.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa

@@ -8,6 +8,8 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -252,6 +254,156 @@ LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
return success();
}
using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
Operation *,
RankedTensorType, Value,
ElementsAttr, bool);
// These patterns all share a common form and invoke the appropriate
// conversion function in TosaLegalizeCommon.cpp
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
// Each variant must implement its own parsing of the reduce_dims and
// keep_dims parameters.
virtual LogicalResult readReduceDimsAndKeepDims(
AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr, bool &keepDims) const {
return rewriter.notifyMatchFailure(
op, "Unimplemented reduce_dims and keep_dims parsing function");
}
// Common rewriter for all reduction ops; calls the implementation of
// readReduceDimsAndKeepDims() specific to the op variant.
LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().dyn_cast<TensorType>();
if (!selfTy)
return op.emitError("Only Tensor types supported in TOSA");
auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!outputTy)
return op.emitError(
"Only ranked tensor type outputs permitted for reduce ops");
ElementsAttr reduceDimsAttr;
bool keepDims;
if (failed(readReduceDimsAndKeepDims(op, adaptor, rewriter, reduceDimsAttr,
keepDims)))
return failure();
llvm::Optional<Value> result =
ConversionFuncT(rewriter, op, outputTy, self, reduceDimsAttr, keepDims);
if (!result)
return failure();
// TBD - support dtype casting.
rewriter.replaceOp(op, {result.getValue()});
return success();
}
};
// This reduction op legalization template handles op variants that have
// explicit reduce_dims dimensions (provided as a list) and a keep_dims
// parameter.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenMultipleDimsReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
SmallVector<int64_t, 4> reduceDims;
if (!matchPattern(op.dim(), m_TorchConstantIntList(reduceDims)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
keepDims = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");
return success();
}
};
// This reduction op legalization template handles op variants that reduce in
// only one explicit dim which is provided as a number (rather than a list), and
// a keep_dims parameter.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenOneDimReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
int64_t reduceDim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef({reduceDim}));
keepDims = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");
return success();
}
};
// This reduction op legalization template handles op variants that reduce
// all dims and do not keep dims.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenAllDimsReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
public:
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
// Select all dims to reduce
SmallVector<int64_t, 4> reduceDims;
for (int64_t i = 0; i < selfTy.getRank(); i++)
reduceDims.push_back(i);
int64_t N = selfTy.getRank();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
keepDims = false;
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
@@ -300,6 +452,36 @@ public:
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
mlir::tosa::convertReduceMeanOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
mlir::tosa::convertReduceSumOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN
#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
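// Note: convertReduceMinOp, convertReduceMaxOp, and convertReduceProdOp are
// declared in TosaLegalizeCommon.h but not registered by this change. A
// hypothetical registration (AtenMaxOp is illustrative only) would expand
// the same macro pattern as above:
target.addIllegalOp<AtenMaxOp>();
patterns.add<ConvertAtenAllDimsReductionOp<AtenMaxOp,
mlir::tosa::convertReduceMaxOp>>(
typeConverter, context);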
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);

@@ -0,0 +1,314 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
// Common function for lowering reduce operations to TOSA ops.
template <typename T>
llvm::Optional<Value> convertReduceOpCommon(
PatternRewriter &rewriter, Operation *op, RankedTensorType output_type,
Value input_value, ElementsAttr axes_elems, bool keep_dims,
Type reduce_element_type, bool is_quantized, double input_scale,
int64_t input_zp, double output_scale, int64_t output_zp) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
ArrayRef<int64_t> input_shape = input_type.getShape();
ArrayRef<int64_t> output_shape = output_type.getShape();
auto input_rank = input_shape.size();
Value val = input_value;
if (axes_elems.getNumElements() == 0) {
// No axes means return the original tensor.
auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
rewriter, op->getLoc(), output_type, val);
val = identity_op.getResult();
} else {
// Reduce along each axis
SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());
if (is_quantized) {
val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
}
for (int i = 0; i < axes_elems.getNumElements(); i++) {
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
if (axis_val < 0)
axis_val += input_rank;
auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
shape_vec[axis_val] = 1;
RankedTensorType reduce_type =
RankedTensorType::get(shape_vec, reduce_element_type);
auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
val = reduce_op.getResult();
}
if (is_quantized) {
RankedTensorType output_rescale_type =
RankedTensorType::get(shape_vec, output_type.getElementType());
val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
0, output_zp, false, true);
}
// Optionally squeeze out the reduced axes.
if (!keep_dims) {
auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), output_type, val,
rewriter.getI64ArrayAttr(output_shape));
val = reshape_op.getResult();
}
}
return val;
}
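// Standalone sketch of the shape bookkeeping above, with illustrative
// values: a 2x3x4 input reduced over axes {0, 2} and keep_dims == false.
SmallVector<int64_t> sketch_shape{2, 3, 4};
for (int64_t axis_val : {0, 2})
  sketch_shape[axis_val] = 1; // each reduce keeps a unit dim on its axis
// sketch_shape is now {1, 3, 1}: reduce axis 0 -> 1x3x4, then axis 2 ->
// 1x3x1; with keep_dims == false the final tosa.reshape yields shape {3}.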
// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
return convertReduceOpCommon<tosa::ReduceAllOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}
// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
return convertReduceOpCommon<tosa::ReduceAnyOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}
// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
return convertReduceOpCommon<tosa::ReduceMinOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}
// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
return convertReduceOpCommon<tosa::ReduceMaxOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}
// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype || output_is_qtype) {
op->emitOpError("ConvertReduceProdOp: input/output tensor should "
"be all floating-point.");
return llvm::None;
}
return convertReduceOpCommon<tosa::ReduceProdOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}
// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
"be all quantized or all floating-point.");
return llvm::None;
}
double input_scale = 1.0f;
double output_scale = 1.0f;
int64_t input_zp = 0;
int64_t output_zp = 0;
Type reduce_element_type = input_type.getElementType();
if (input_is_qtype) {
auto input_qtype =
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
auto output_qtype =
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
int32_t input_shift = 20;
input_scale =
static_cast<double>(1 << input_shift) * input_qtype.getScale();
output_scale =
1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
input_zp = input_qtype.getZeroPoint();
output_zp = output_qtype.getZeroPoint();
reduce_element_type = rewriter.getI32Type();
}
return convertReduceOpCommon<tosa::ReduceSumOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
output_zp);
}
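// The two rescale stages above compose to the end-to-end ratio
// s_in / s_out; the 1 << 20 factor only buys headroom in the int32
// accumulator and cancels out. A quick arithmetic sketch with illustrative
// scale values (needs <cassert> and <cmath>):
double s_in = 0.05, s_out = 0.1;
double in_sc = static_cast<double>(1 << 20) * s_in;
double out_sc = 1.0 / (s_out * static_cast<double>(1 << 20));
assert(std::fabs(in_sc * out_sc - s_in / s_out) < 1e-12); // both are 0.5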
// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims) {
// reduce_mean is lowered as follows:
// op1 = reduce_sum(input)
// op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type)
return llvm::None;
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
"be all quantized or all floating-point.");
return llvm::None;
}
// Non-quantized mean() only supports a floating-point output element type
if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element "
"not FloatType!");
return llvm::None;
}
int64_t input_rank = input_type.getRank();
int64_t num_elems_on_reduced_axis = 1;
for (int i = 0; i < axes_elems.getNumElements(); i++) {
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
if (axis_val < 0)
axis_val += input_rank;
num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
}
double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
double input_scale = 1.0f;
double output_scale = 1.0f;
int64_t input_zp = 0;
int64_t output_zp = 0;
Type reduce_element_type = input_type.getElementType();
if (input_is_qtype) {
auto input_qtype =
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
auto output_qtype =
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
// Combine 'div_scale' as part of output rescale
output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
input_zp = input_qtype.getZeroPoint();
output_zp = output_qtype.getZeroPoint();
reduce_element_type = rewriter.getI32Type();
}
auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims,
reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
output_zp);
if (!val.hasValue())
return llvm::None;
if (!input_is_qtype) {
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
val.getValue(), div_const, 0)
.getResult();
}
return val;
}
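// Illustrative arithmetic for the mean factor above: reducing axes {1, 2}
// of a 2x3x4 input gives 3 * 4 == 12 elements per reduced slice, so the
// reduced sum is multiplied by 1/12 (via tosa.mul in the float path, or
// folded into output_scale in the quantized path).
int64_t sketch_dims[] = {2, 3, 4};
int64_t sketch_elems = 1;
for (int64_t axis_val : {1, 2})
  sketch_elems *= sketch_dims[axis_val]; // 12
double sketch_div_scale = 1.0 / static_cast<double>(sketch_elems); // 1/12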
} // namespace tosa
} // namespace mlir

@@ -0,0 +1,67 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
namespace mlir {
namespace tosa {
// Create a TOSA rescale op from input framework scaling, zero points and
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, bool double_round,
bool scale32) {
int32_t multiplier;
int32_t shift;
int32_t scale_width = scale32 ? 32 : 16;
computeMultiplierAndShift(scale, multiplier, shift, scale_width);
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, input_val,
rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
rewriter.getBoolAttr(false));
return rescale_op.getResult();
}
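// Assumed fixed-point reading of computeMultiplierAndShift (the helper
// picks its own normalization): scale ~= multiplier * 2^-shift. A
// hand-picked sanity check with illustrative values (needs <cassert> and
// <cmath>):
double sketch_scale = 0.5;
int32_t sketch_multiplier = 1 << 30;
int32_t sketch_shift = 31;
assert(sketch_scale ==
std::ldexp(static_cast<double>(sketch_multiplier), -sketch_shift));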
// Creates a TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value input_val, double input_scale,
int64_t input_zp) {
// Output is always int32 type
auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
assert(input_type);
auto output_type = input_type.clone(rewriter.getI32Type());
return buildRescale(rewriter, op, output_type, input_val, input_scale,
input_zp, 0, false, true);
}
// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
} // namespace tosa
} // namespace mlir

@@ -240,10 +240,10 @@ public:
AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenReciprocalOp>(op)) {
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@@ -300,7 +300,8 @@ public:
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(
op)) {
return visitBinaryTensorScalarOp(op, operands);
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
@@ -1375,13 +1376,15 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
// if the scalar is part of a tensor operation (such as AtenMulScalar) or
// not. In the former case, the type promotion rules are captured by the
// `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
// rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
// rules in
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
// `NumToTensor` falls in the latter case.
Type type = op.a().getType();
if (type.isa<Torch::FloatType>())
knowledge.dtype = Float64Type::get(op.getContext());
else if (type.isa<Torch::IntType>())
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
knowledge.dtype =
IntegerType::get(op.getContext(), 64, IntegerType::Signed);
return getLatticeElement(op.getResult()).join(knowledge);
}

@@ -167,3 +167,121 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func @test_reduce_mean_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESHAPE_SUM:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[CONST:.*]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%dim0 = torch.constant.int 0
%reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<!torch.int>
%keepdims = torch.constant.bool false
%dtype = torch.constant.none
%0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @test_reduce_sum_dims$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3:.*]] = torch.constant.int 0
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
return %1 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @test_reduce_sum$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
%none = torch.constant.none
%0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// -----
// CHECK-LABEL: func @test_reduce_all$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1>
}
// -----
// CHECK-LABEL: func @test_reduce_any_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG2:.*]] = torch.constant.bool false
// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1>
return %0 : !torch.vtensor<[?,?,?],i1>
}
// -----
// CHECK-LABEL: func @test_reduce_any$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1>
}