mirror of https://github.com/llvm/torch-mlir
[tosa] Add Torch reduction operators
- Supports variants with multiple dims, one dim, and all dims
- Leverages legalize_common and legalize_utils code from TensorFlow-TOSA work

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>

pull/461/head
parent ab6211184f
commit c9c9b68d1f
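As a quick illustration of the lowered IR, the multi-dim variant in the new tests below (aten.mean.dim over axis 0 of a dynamically shaped tensor) legalizes to roughly the following sequence; SSA names are illustrative, and the ops are copied from the FileCheck expectations added in this commit. The -1.0 constant appears because 1.0 / num_elements is computed from the reduced axis's dynamic extent, which the shape machinery reports as -1 in this test.

%sum = "tosa.reduce_sum"(%input) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
%squeezed = "tosa.reshape"(%sum) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
%recip = "tosa.const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%mean = "tosa.mul"(%squeezed, %recip) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>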
@@ -0,0 +1,64 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H

#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h"    // from @llvm-project

namespace mlir {
namespace tosa {

// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims);

} // namespace tosa
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
@@ -0,0 +1,96 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/QuantTypes.h"        // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"   // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"            // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"                 // from @llvm-project
#include "mlir/IR/PatternMatch.h"                 // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h"                    // from @llvm-project

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from input framework scaling, zero points and
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32);

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                          Value input_val, double input_scale,
                          int64_t input_zp);

// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
                        Args &&... args) {
  auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto new_ty = newKnowledge.getType();
  result.setType(new_ty);
  return op;
}

template <typename TosaOp, typename... Args>
void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
                             Type result_ty, Args &&... args) {
  auto result =
      CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...);
  rewriter.replaceOp(op, result->getResults());
}

} // namespace tosa
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
@@ -1,5 +1,7 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
  TorchToTosa.cpp
  TosaLegalizeUtils.cpp
  TosaLegalizeCommon.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
@@ -8,6 +8,8 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -252,6 +254,156 @@ LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
  return success();
}

using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
                                                    Operation *,
                                                    RankedTensorType, Value,
                                                    ElementsAttr, bool);

// These legalizations all share a common form that invokes the appropriate
// conversion function in TosaLegalizeCommon.cpp
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Each variant must implement corresponding parameter parsing options
  virtual LogicalResult readReduceDimsAndKeepDims(
      AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      ElementsAttr &reduceDimsAttr, bool &keepDims) const {
    return rewriter.notifyMatchFailure(
        op, "Unimplemented reduce_dims and keep_dims parsing function");
  }

  // Common rewriter for all reduction ops, calls the specific implementation
  // of readReduceDimsAndKeepDims() needed for the op variant.
  LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    Value self = adaptor.self();
    auto selfTy = self.getType().cast<TensorType>();

    if (!selfTy)
      return op.emitError("Only Tensor types supported in TOSA");

    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<RankedTensorType>();
    if (!outputTy)
      return op.emitError(
          "Only ranked tensor type outputs permitted for reduce_mean");

    ElementsAttr reduceDimsAttr;
    bool keepDims;

    if (failed(readReduceDimsAndKeepDims(op, adaptor, rewriter, reduceDimsAttr,
                                         keepDims)))
      return failure();

    llvm::Optional<Value> result =
        ConversionFuncT(rewriter, op, outputTy, self, reduceDimsAttr, keepDims);

    if (!result)
      return failure();

    // TBD - support dtype casting.

    rewriter.replaceOp(op, {result.getValue()});

    return success();
  }
};

// This reduction op legalization template handles op variants that have
// explicit reduce_dims dimensions (provided as a list) and keep_dims
// parameters.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenMultipleDimsReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const {
    SmallVector<int64_t, 4> reduceDims;
    if (!matchPattern(op.dim(), m_TorchConstantIntList(reduceDims)))
      return rewriter.notifyMatchFailure(op,
                                         "non-const dim parameter unsupported");
    int64_t N = reduceDims.size();
    auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef(reduceDims));

    keepDims = false;
    if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
      return rewriter.notifyMatchFailure(
          op, "non-const keepdim parameter unsupported");

    return success();
  }
};

// This reduction op legalization template handles op variants that reduce in
// only one explicit dim which is provided as a number (rather than a list),
// and a keep_dims parameter.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenOneDimReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const {
    int64_t reduceDim;
    if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim)))
      return rewriter.notifyMatchFailure(op,
                                         "non-const dim parameter unsupported");
    auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef({reduceDim}));

    keepDims = false;
    if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
      return rewriter.notifyMatchFailure(
          op, "non-const keepdim parameter unsupported");

    return success();
  }
};

// This reduction op legalization template handles op variants that reduce all
// dims and do not keep dims.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenAllDimsReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
public:
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const {
    auto self = adaptor.self();
    auto selfTy = self.getType().template cast<RankedTensorType>();

    // Select all dims to reduce
    SmallVector<int64_t, 4> reduceDims;
    for (int64_t i = 0; i < selfTy.getRank(); i++)
      reduceDims.push_back(i);
    int64_t N = selfTy.getRank();
    auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef(reduceDims));
    keepDims = false;

    return success();
  }
};

} // namespace

// -----------------------------------------------------------------------------
@@ -300,6 +452,36 @@ public:
    INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN

#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)            \
  target.addIllegalOp<AtenOp>();                                             \
  patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>(  \
      typeConverter, context);
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
                                      mlir::tosa::convertReduceMeanOp)
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
                                      mlir::tosa::convertReduceSumOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN

#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)           \
  target.addIllegalOp<AtenOp>();                                             \
  patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>(        \
      typeConverter, context);
    INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
                                       mlir::tosa::convertReduceAnyOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN

#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)          \
  target.addIllegalOp<AtenOp>();                                             \
  patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>(       \
      typeConverter, context);
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
                                        mlir::tosa::convertReduceAllOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
                                        mlir::tosa::convertReduceAnyOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
                                        mlir::tosa::convertReduceSumOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp)                                        \
  target.addIllegalOp<AtenOp>();                                             \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
@@ -0,0 +1,314 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"          // from @llvm-project
#include "mlir/IR/Matchers.h"              // from @llvm-project
#include "mlir/IR/PatternMatch.h"          // from @llvm-project
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
namespace tosa {

// Common function for lowering reduce operations to TOSA ops.
template <typename T>
llvm::Optional<Value> convertReduceOpCommon(
    PatternRewriter &rewriter, Operation *op, RankedTensorType output_type,
    Value input_value, ElementsAttr axes_elems, bool keep_dims,
    Type reduce_element_type, bool is_quantized, double input_scale,
    int64_t input_zp, double output_scale, int64_t output_zp) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  ArrayRef<int64_t> input_shape = input_type.getShape();
  ArrayRef<int64_t> output_shape = output_type.getShape();
  auto input_rank = input_shape.size();
  Value val = input_value;

  if (axes_elems.getNumElements() == 0) {
    // No axes means return the original tensor.
    auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
        rewriter, op->getLoc(), output_type, val);
    val = identity_op.getResult();
  } else {
    // Reduce along each axis
    SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());

    if (is_quantized) {
      val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
    }

    for (int i = 0; i < axes_elems.getNumElements(); i++) {
      int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
      if (axis_val < 0)
        axis_val += input_rank;
      auto axis_attr = rewriter.getI64IntegerAttr(axis_val);

      shape_vec[axis_val] = 1;
      RankedTensorType reduce_type =
          RankedTensorType::get(shape_vec, reduce_element_type);

      auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
                                           val, axis_attr);

      val = reduce_op.getResult();
    }

    if (is_quantized) {
      RankedTensorType output_rescale_type =
          RankedTensorType::get(shape_vec, output_type.getElementType());
      val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
                         0, output_zp, false, true);
    }

    // Optionally squeeze out the reduced axes.
    if (!keep_dims) {
      auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
          rewriter, op->getLoc(), output_type, val,
          rewriter.getI64ArrayAttr(output_shape));
      val = reshape_op.getResult();
    }
  }

  return val;
}

// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceAllOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceAnyOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMinOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  return convertReduceOpCommon<tosa::ReduceMaxOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype || output_is_qtype) {
    op->emitOpError("ConvertReduceProdOp: input/output tensor should "
                    "be all floating-point.");
    return llvm::None;
  }

  return convertReduceOpCommon<tosa::ReduceProdOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
}

// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceSumOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  double input_scale = 1.0f;
  double output_scale = 1.0f;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    int32_t input_shift = 20;

    input_scale =
        static_cast<double>(1 << input_shift) * input_qtype.getScale();
    output_scale =
        1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  return convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);
}

// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  // reduce_mean is lowered as follows:
  // op1 = reduce_sum(input)
  // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return llvm::None;

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceMeanOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return llvm::None;
  }

  // Only supports float type mean() if it's non-quantized
  if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
    op->emitWarning(
        "Failed convertReduceMean: input unquantized type but output element "
        "not FloatType!");
    return llvm::None;
  }

  int64_t input_rank = input_type.getRank();
  int64_t num_elems_on_reduced_axis = 1;
  for (int i = 0; i < axes_elems.getNumElements(); i++) {
    int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
    if (axis_val < 0)
      axis_val += input_rank;
    num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
  }
  double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);

  double input_scale = 1.0f;
  double output_scale = 1.0f;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // Combine 'div_scale' as part of output rescale
    output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);

  if (!val.hasValue())
    return llvm::None;

  if (!input_is_qtype) {
    Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
    return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
                                         val.getValue(), div_const, 0)
        .getResult();
  }

  return val;
}

} // namespace tosa
} // namespace mlir
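Note that every test added below uses keepdim = false, so each lowered sequence ends in the tosa.reshape that squeezes out the reduced axes; with keep_dims = true, convertReduceOpCommon above would stop after the reduce op itself since the reduced axis is already kept as size 1. A hypothetical keepdim-true sketch (not part of this patch's test suite) would therefore reduce to a single op:

%sum = "tosa.reduce_sum"(%input) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>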
@@ -0,0 +1,67 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"       // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from input framework tensor, zero points and
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32) {
  int32_t multiplier;
  int32_t shift;

  int32_t scale_width = scale32 ? 32 : 16;

  computeMultiplierAndShift(scale, multiplier, shift, scale_width);

  auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
      rewriter, op->getLoc(), output_type, input_val,
      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
      rewriter.getBoolAttr(false));

  return rescale_op.getResult();
}

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());

  return buildRescale(rewriter, op, output_type, input_val, input_scale,
                      input_zp, 0, false, true);
}

// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

} // namespace tosa
} // namespace mlir
@@ -240,10 +240,10 @@ public:
            AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
            AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
            AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
-           AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
-           AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
-           AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-           AtenAbsOp, AtenReciprocalOp>(op)) {
+           AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
+           AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
+           AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
+           AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -300,7 +300,8 @@ public:
      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
    } else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
                   AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
-                  AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
+                  AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(
+                   op)) {
      return visitBinaryTensorScalarOp(op, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
@@ -1375,13 +1376,15 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
  // if the scalar is part of a tensor operation (such as AtenMulScalar) or
  // not. In the former case, the type promotion rules are captured by the
  // `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
-  // rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
+  // rules in
+  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
  // `NumToTensor` falls in the latter case.
  Type type = op.a().getType();
  if (type.isa<Torch::FloatType>())
    knowledge.dtype = Float64Type::get(op.getContext());
  else if (type.isa<Torch::IntType>())
-    knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
+    knowledge.dtype =
+        IntegerType::get(op.getContext(), 64, IntegerType::Signed);

  return getLatticeElement(op.getResult()).join(knowledge);
}
@@ -167,3 +167,121 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
  return %0 : !torch.vtensor<[?, ?],f32>
}

// -----

// CHECK-LABEL: func @test_reduce_mean_dim$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESHAPE_SUM:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[CONST:.*]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %dim0 = torch.constant.int 0
  %reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<!torch.int>
  %keepdims = torch.constant.bool false
  %dtype = torch.constant.none
  %0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}

// -----

// CHECK-LABEL: func @test_reduce_sum_dims$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3:.*]] = torch.constant.int 0
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<!torch.int>
  %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
  return %1 : !torch.vtensor<[?,?,?],f32>
}

// -----

// CHECK-LABEL: func @test_reduce_sum$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
  %none = torch.constant.none
  %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}

// -----

// CHECK-LABEL: func @test_reduce_all$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
  %0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
  return %0 : !torch.vtensor<[1],i1>
}

// -----

// CHECK-LABEL: func @test_reduce_any_dim$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG2:.*]] = torch.constant.bool false
// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1>
  return %0 : !torch.vtensor<[?,?,?],i1>
}

// -----

// CHECK-LABEL: func @test_reduce_any$basic(
// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
  %0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
  return %0 : !torch.vtensor<[1],i1>
}