2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
|
|
|
|
|
|
#include "PopulatePatterns.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2024-04-02 16:33:30 +08:00
|
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/IR/Matchers.h"
|
2023-12-02 08:38:21 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
#include "llvm/ADT/APSInt.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
namespace {
|
2023-12-05 23:16:35 +08:00
|
|
|
// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an
|
|
|
|
// linalg.indexed_generic op, producing two output buffers.
|
2022-03-11 01:54:13 +08:00
|
|
|
//
|
2023-12-05 23:16:35 +08:00
|
|
|
// The first output buffer contains the maximum (minium) value found. It is
|
|
|
|
// initialized to the minimum (maximum) representable value of the input
|
|
|
|
// element type.
|
2022-03-11 01:54:13 +08:00
|
|
|
//
|
2023-12-05 23:16:35 +08:00
|
|
|
// The second output buffer contains the index of the found maximum (minimum)
|
|
|
|
// value. It is initialized to 0 and is resulting integer type.
|
2022-03-11 01:54:13 +08:00
|
|
|
//
|
2023-12-05 23:16:35 +08:00
|
|
|
// The indexed_generic op updates both the maximum (minimum) value and index
|
|
|
|
// if the current value exceeds the running max (min).
|
|
|
|
template <typename OpTy>
|
|
|
|
class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
|
2022-03-11 01:54:13 +08:00
|
|
|
public:
|
2023-12-05 23:16:35 +08:00
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
using OpConversionPattern<OpTy>::getTypeConverter;
|
|
|
|
|
|
|
|
using OpAdaptor = typename OpTy::Adaptor;
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
LogicalResult
|
2023-12-05 23:16:35 +08:00
|
|
|
matchAndRewrite(OpTy op, OpAdaptor adaptor,
|
2022-03-11 01:54:13 +08:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2023-12-05 23:16:35 +08:00
|
|
|
static_assert(std::is_same<OpTy, AtenMaxDimOp>() ||
|
|
|
|
std::is_same<OpTy, AtenMinDimOp>());
|
|
|
|
constexpr bool isMax = std::is_same<OpTy, AtenMaxDimOp>();
|
|
|
|
const llvm::StringRef opName = op->getName().getStringRef();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2023-12-05 23:16:35 +08:00
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-02-28 14:48:07 +08:00
|
|
|
auto typec = this->getTypeConverter();
|
|
|
|
auto valResultType =
|
|
|
|
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
|
|
|
auto idxResultType =
|
|
|
|
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
2024-02-28 14:48:07 +08:00
|
|
|
Type idxElementType =
|
|
|
|
getElementTypeOrSelf(typec->convertType(idxResultType));
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<IntegerType>(idxElementType))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-05 23:16:35 +08:00
|
|
|
op, opName + " to linalg.* requires integer-like result type");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
bool keepDim = false;
|
2023-12-05 23:16:35 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-05 23:16:35 +08:00
|
|
|
op, opName + " requires boolean value for keepdim");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
int64_t dim;
|
2023-12-05 23:16:35 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-05 23:16:35 +08:00
|
|
|
op, opName + " to linalg.* requires int value for Dim");
|
2022-03-11 01:54:13 +08:00
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
if (!isValidDim(dim, inputType.getRank()))
|
2023-12-05 23:16:35 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Type inElementType = inputType.getElementType();
|
2024-03-07 08:48:21 +08:00
|
|
|
bool isUnsigned = false;
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::FloatType>(inElementType)) {
|
|
|
|
if (isa<mlir::IntegerType>(inElementType)) {
|
2024-05-31 14:45:13 +08:00
|
|
|
auto integerTy = dyn_cast<mlir::IntegerType>(
|
|
|
|
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
|
2024-03-07 08:48:21 +08:00
|
|
|
isUnsigned = integerTy.isUnsigned();
|
2023-02-06 19:52:04 +08:00
|
|
|
} else {
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-05 23:16:35 +08:00
|
|
|
op, opName + " to linalg.* requires Float or Integer "
|
2024-01-30 01:59:33 +08:00
|
|
|
"input element type");
|
2023-02-06 19:52:04 +08:00
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Constant op to account for the reduction along dim.
|
|
|
|
SmallVector<Value> resultShape;
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
|
|
|
if (dim != i) {
|
|
|
|
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
|
|
|
resultShape.push_back(currentDimSize);
|
2024-02-28 14:48:07 +08:00
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
// First fill the output buffer for the index.
|
|
|
|
Value filledTensorIdx =
|
|
|
|
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
|
|
|
|
|
2023-12-05 23:16:35 +08:00
|
|
|
// Second fill the output buffer for the running max or min.
|
|
|
|
Value initTensorVal = rewriter.create<tensor::EmptyOp>(
|
2022-10-18 12:22:53 +08:00
|
|
|
loc, getAsOpFoldResult(resultShape), inElementType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2023-12-05 23:16:35 +08:00
|
|
|
Value fillValue;
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(inElementType)) {
|
2023-12-05 23:16:35 +08:00
|
|
|
fillValue = rewriter.create<arith::ConstantOp>(
|
2024-04-11 21:47:35 +08:00
|
|
|
loc, rewriter.getFloatAttr(
|
|
|
|
inElementType,
|
|
|
|
APFloat::getInf(
|
|
|
|
cast<mlir::FloatType>(inElementType).getFloatSemantics(),
|
|
|
|
/*Negative=*/isMax)));
|
2024-03-07 08:48:21 +08:00
|
|
|
} else if (!isUnsigned) {
|
2024-04-11 21:47:35 +08:00
|
|
|
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
|
2023-12-05 23:16:35 +08:00
|
|
|
auto init = isMax ? APSInt::getSignedMinValue(width)
|
|
|
|
: APSInt::getSignedMaxValue(width);
|
|
|
|
fillValue = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(inElementType, init));
|
2024-03-07 08:48:21 +08:00
|
|
|
} else if (isUnsigned) {
|
2024-04-11 21:47:35 +08:00
|
|
|
auto width = cast<mlir::IntegerType>(inElementType).getWidth();
|
2024-03-07 08:48:21 +08:00
|
|
|
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
|
|
|
|
fillValue = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(inElementType, init));
|
2023-02-06 19:52:04 +08:00
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2023-12-05 23:16:35 +08:00
|
|
|
Value filledTensorVal =
|
2024-01-30 01:59:33 +08:00
|
|
|
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2024-02-28 14:48:07 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
inputType.getRank(), utils::IteratorType::parallel);
|
|
|
|
iteratorTypes[dim] = utils::IteratorType::reduction;
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
// Create the affine expressions that will be used to
|
|
|
|
// iterate over the input and output tensors.
|
|
|
|
// Here we also set the type of iterator: parallel or reduction.
|
2024-02-28 14:48:07 +08:00
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<AffineExpr> exprs;
|
|
|
|
SmallVector<AffineExpr> resultExprs;
|
2022-11-29 20:33:31 +08:00
|
|
|
for (auto size :
|
|
|
|
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
|
2022-03-11 01:54:13 +08:00
|
|
|
exprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
2024-02-28 14:48:07 +08:00
|
|
|
if (unsigned(dim) != size.index())
|
2022-03-11 01:54:13 +08:00
|
|
|
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
|
|
|
}
|
2024-02-28 14:48:07 +08:00
|
|
|
|
2024-02-10 06:07:49 +08:00
|
|
|
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs},
|
|
|
|
rewriter.getContext());
|
2022-03-11 01:54:13 +08:00
|
|
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
|
|
|
loc,
|
2023-12-05 23:16:35 +08:00
|
|
|
ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}),
|
|
|
|
input, ValueRange({filledTensorVal, filledTensorIdx}), maps,
|
2022-03-11 01:54:13 +08:00
|
|
|
iteratorTypes,
|
|
|
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
|
|
|
ValueRange blockArgs) {
|
|
|
|
Value newValue = blockArgs[0];
|
|
|
|
Value oldValue = blockArgs[1];
|
|
|
|
Value oldIndex = blockArgs[2];
|
|
|
|
|
|
|
|
Value newIndex = rewriter.create<arith::IndexCastOp>(
|
|
|
|
nestedLoc, oldIndex.getType(),
|
|
|
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
|
|
|
|
2023-12-05 23:16:35 +08:00
|
|
|
Value resultVal, predicate;
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(inElementType)) {
|
2024-01-30 01:59:33 +08:00
|
|
|
arith::CmpFPredicate predType;
|
2023-12-08 02:08:17 +08:00
|
|
|
if (isMax) {
|
2023-12-05 23:16:35 +08:00
|
|
|
predType = arith::CmpFPredicate::OGT;
|
|
|
|
resultVal = rewriter.create<arith::MaximumFOp>(
|
|
|
|
nestedLoc, newValue, oldValue);
|
|
|
|
} else {
|
|
|
|
predType = arith::CmpFPredicate::OLT;
|
|
|
|
resultVal = rewriter.create<arith::MinimumFOp>(
|
|
|
|
nestedLoc, newValue, oldValue);
|
|
|
|
}
|
|
|
|
|
|
|
|
predicate = rewriter.create<arith::CmpFOp>(nestedLoc, predType,
|
2024-01-30 01:59:33 +08:00
|
|
|
newValue, oldValue);
|
2023-02-06 19:52:04 +08:00
|
|
|
} else {
|
2023-12-05 23:16:35 +08:00
|
|
|
arith::CmpIPredicate predType;
|
2023-12-08 02:08:17 +08:00
|
|
|
if (isMax) {
|
2024-03-07 08:48:21 +08:00
|
|
|
predType = isUnsigned ? arith::CmpIPredicate::ugt
|
|
|
|
: arith::CmpIPredicate::sgt;
|
|
|
|
if (isUnsigned) {
|
|
|
|
resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
|
|
|
|
oldValue);
|
|
|
|
} else {
|
|
|
|
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
|
|
|
|
oldValue);
|
|
|
|
}
|
2023-12-05 23:16:35 +08:00
|
|
|
} else {
|
2024-03-07 08:48:21 +08:00
|
|
|
predType = isUnsigned ? arith::CmpIPredicate::ult
|
|
|
|
: arith::CmpIPredicate::slt;
|
|
|
|
if (isUnsigned) {
|
|
|
|
resultVal = rewriter.create<arith::MinUIOp>(nestedLoc, newValue,
|
|
|
|
oldValue);
|
|
|
|
} else {
|
|
|
|
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
|
|
|
|
oldValue);
|
|
|
|
}
|
2023-12-05 23:16:35 +08:00
|
|
|
}
|
|
|
|
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
|
|
|
|
newValue, oldValue);
|
2023-02-06 19:52:04 +08:00
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
auto resultIndex = rewriter.create<arith::SelectOp>(
|
|
|
|
nestedLoc, predicate, newIndex, oldIndex);
|
|
|
|
nestedBuilder.create<linalg::YieldOp>(
|
2023-12-05 23:16:35 +08:00
|
|
|
nestedLoc, ValueRange({resultVal, resultIndex}));
|
2022-03-11 01:54:13 +08:00
|
|
|
});
|
|
|
|
|
2024-02-28 14:48:07 +08:00
|
|
|
if (!keepDim) {
|
|
|
|
Value rVal = rewriter.create<tensor::CastOp>(loc, valResultType,
|
|
|
|
linalgOp.getResult(0));
|
|
|
|
Value rIdx = rewriter.create<tensor::CastOp>(loc, idxResultType,
|
|
|
|
linalgOp.getResult(1));
|
|
|
|
llvm::SmallVector<Value> res{rVal, rIdx};
|
|
|
|
rewriter.replaceOp(op, res);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> valShape(valResultType.getShape());
|
|
|
|
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
|
|
|
|
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
|
|
|
valShape[i] = valShape[i + 1];
|
|
|
|
idxShape[i] = idxShape[i + 1];
|
|
|
|
}
|
|
|
|
|
|
|
|
valShape.resize(valShape.size() - 1);
|
|
|
|
idxShape.resize(idxShape.size() - 1);
|
|
|
|
|
|
|
|
Value rVal = rewriter.create<tensor::CastOp>(
|
|
|
|
loc, valResultType.clone(valShape), linalgOp.getResult(0));
|
|
|
|
Value rIdx = rewriter.create<tensor::CastOp>(
|
|
|
|
loc, idxResultType.clone(idxShape), linalgOp.getResult(1));
|
|
|
|
|
|
|
|
SmallVector<ReassociationIndices> reassociation(valShape.size());
|
|
|
|
if (reassociation.size() > 0) {
|
|
|
|
for (int i = 0; i < dim; ++i)
|
|
|
|
reassociation[i].push_back(i);
|
|
|
|
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
|
|
|
|
for (int i = dim, s = reassociation.size(); i < s; ++i)
|
|
|
|
reassociation[i].push_back(i + 1);
|
|
|
|
}
|
|
|
|
|
|
|
|
valShape.push_back(0);
|
|
|
|
idxShape.push_back(0);
|
|
|
|
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
|
|
|
|
valShape[i + 1] = valShape[i];
|
|
|
|
idxShape[i + 1] = idxShape[i];
|
|
|
|
}
|
|
|
|
|
|
|
|
valShape[dim] = 1;
|
|
|
|
idxShape[dim] = 1;
|
|
|
|
|
|
|
|
Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
|
|
|
|
loc, valResultType, rVal, reassociation);
|
|
|
|
|
|
|
|
Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
|
|
|
|
loc, idxResultType, rIdx, reassociation);
|
|
|
|
|
|
|
|
llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
|
|
|
|
rewriter.replaceOp(op, unsqueezes);
|
2022-03-11 01:54:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
2023-12-05 23:16:35 +08:00
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
} // namespace
|
|
|
|
|
2024-04-02 16:33:30 +08:00
|
|
|
/// Builds |elem| for the norm reductions.
///
/// A non-complex element is first converted to `resultElementType` and then
/// passed through math.absf. A complex element is passed to complex.abs as-is,
/// without consulting `resultElementType`.
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
                                   Type resultElementType) {
  if (!isa<mlir::ComplexType>(elem.getType())) {
    Value converted = convertScalarToDtype(b, loc, elem, resultElementType);
    return b.create<math::AbsFOp>(loc, converted);
  }
  return b.create<complex::AbsOp>(loc, elem);
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
|
|
|
Operation *op, Type elementType) {
|
2022-03-11 01:54:13 +08:00
|
|
|
if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
|
|
|
|
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
|
|
|
|
2024-04-24 11:14:04 +08:00
|
|
|
if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(elementType))
|
2023-09-06 04:38:51 +08:00
|
|
|
return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 1.0));
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(elementType))
|
2023-09-06 04:38:51 +08:00
|
|
|
return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(elementType, 1));
|
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
if (isa<AtenMaxOp>(op)) {
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(elementType))
|
2022-03-11 01:54:13 +08:00
|
|
|
return b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getFloatAttr(
|
|
|
|
elementType,
|
2023-05-09 00:17:49 +08:00
|
|
|
APFloat::getInf(
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
2022-03-11 01:54:13 +08:00
|
|
|
/*Negative=*/true)));
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(elementType) &&
|
2022-03-11 01:54:13 +08:00
|
|
|
elementType.getIntOrFloatBitWidth() != 8)
|
|
|
|
return b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getIntegerAttr(elementType,
|
|
|
|
APSInt::getSignedMinValue(
|
|
|
|
elementType.getIntOrFloatBitWidth())));
|
|
|
|
}
|
|
|
|
|
2023-08-30 01:12:41 +08:00
|
|
|
if (isa<AtenMinOp>(op)) {
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(elementType))
|
2023-08-30 01:12:41 +08:00
|
|
|
return b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getFloatAttr(
|
|
|
|
elementType,
|
|
|
|
APFloat::getInf(
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
2023-08-30 01:12:41 +08:00
|
|
|
/*Negative=*/false)));
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(elementType) &&
|
2023-08-30 01:12:41 +08:00
|
|
|
elementType.getIntOrFloatBitWidth() != 8)
|
|
|
|
return b.create<arith::ConstantOp>(
|
|
|
|
loc, b.getIntegerAttr(elementType,
|
|
|
|
APSInt::getSignedMaxValue(
|
|
|
|
elementType.getIntOrFloatBitWidth())));
|
|
|
|
}
|
|
|
|
|
2024-02-27 00:46:56 +08:00
|
|
|
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
|
|
|
|
isa<AtenNormScalarOp>(op))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
|
|
|
|
2024-04-25 11:15:52 +08:00
|
|
|
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
2024-02-08 04:34:52 +08:00
|
|
|
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
|
|
|
|
}
|
|
|
|
|
2024-04-25 11:15:52 +08:00
|
|
|
if (isa<AtenAnyOp>(op)) {
|
|
|
|
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
|
|
|
|
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
op->emitError("unimplemented lowering in createInitElementForReduceOp");
|
2022-03-11 01:54:13 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|
|
|
ValueRange payloadArgs,
|
|
|
|
Operation *op,
|
|
|
|
ArrayRef<Value> operands,
|
|
|
|
Type resultElementType) {
|
2022-03-11 01:54:13 +08:00
|
|
|
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
|
|
|
|
Value self =
|
|
|
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(resultElementType))
|
2022-03-11 01:54:13 +08:00
|
|
|
return b.create<arith::AddFOp>(loc, self, result);
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(resultElementType))
|
2022-03-11 01:54:13 +08:00
|
|
|
return b.create<arith::AddIOp>(loc, self, result);
|
2024-04-24 11:14:04 +08:00
|
|
|
} else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
|
2023-09-06 04:38:51 +08:00
|
|
|
Value self =
|
|
|
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(resultElementType))
|
2023-09-06 04:38:51 +08:00
|
|
|
return b.create<arith::MulFOp>(loc, self, result);
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(resultElementType))
|
2023-09-06 04:38:51 +08:00
|
|
|
return b.create<arith::MulIOp>(loc, self, result);
|
2022-03-11 01:54:13 +08:00
|
|
|
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
|
|
|
|
Value self =
|
|
|
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(resultElementType))
|
2023-09-20 01:50:53 +08:00
|
|
|
return b.create<arith::MaximumFOp>(loc, self, result);
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
2024-05-31 14:45:13 +08:00
|
|
|
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
|
|
|
cast<BaseTensorType>(max.getSelf().getType()).getDtype());
|
2022-03-11 01:54:13 +08:00
|
|
|
if (intType.isUnsigned())
|
|
|
|
return b.create<arith::MaxUIOp>(loc, self, result);
|
|
|
|
if (intType.isSigned())
|
|
|
|
return b.create<arith::MaxSIOp>(loc, self, result);
|
|
|
|
}
|
2023-08-30 01:12:41 +08:00
|
|
|
} else if (auto min = dyn_cast<AtenMinOp>(op)) {
|
|
|
|
Value self =
|
|
|
|
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(resultElementType))
|
2023-09-20 01:50:53 +08:00
|
|
|
return b.create<arith::MinimumFOp>(loc, self, result);
|
2024-04-11 21:47:35 +08:00
|
|
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
2024-05-31 14:45:13 +08:00
|
|
|
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
|
|
|
cast<BaseTensorType>(min.getSelf().getType()).getDtype());
|
2023-08-30 01:12:41 +08:00
|
|
|
if (intType.isUnsigned())
|
|
|
|
return b.create<arith::MinUIOp>(loc, self, result);
|
|
|
|
if (intType.isSigned())
|
|
|
|
return b.create<arith::MinSIOp>(loc, self, result);
|
|
|
|
}
|
2024-02-27 00:46:56 +08:00
|
|
|
} else if (isa<AtenNormScalarOp>(op)) {
|
|
|
|
// This creates payload for only the first of the two linalg.generic ops.
|
|
|
|
// TODO: Short-circuit operations if `p` is zero or one.
|
|
|
|
Value elem = payloadArgs[0];
|
|
|
|
Value result = payloadArgs[1];
|
|
|
|
|
|
|
|
AtenNormScalarOp::Adaptor adaptor(operands);
|
|
|
|
Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
|
2024-04-02 16:33:30 +08:00
|
|
|
|
|
|
|
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
|
2024-02-27 00:46:56 +08:00
|
|
|
auto pow = b.create<math::PowFOp>(loc, abs, p);
|
|
|
|
return b.create<arith::AddFOp>(loc, pow, result);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
} else if (isa<AtenLinalgVectorNormOp>(op)) {
|
|
|
|
// This creates payload for only the first of the two linalg.generic ops.
|
|
|
|
// TODO: Short-circuit operations if `ord` is zero or one.
|
|
|
|
Value elem = payloadArgs[0];
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-02 16:33:30 +08:00
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
|
2024-01-30 01:59:33 +08:00
|
|
|
Value ord =
|
|
|
|
convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
|
2024-04-02 16:33:30 +08:00
|
|
|
|
|
|
|
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
auto pow = b.create<math::PowFOp>(loc, abs, ord);
|
|
|
|
return b.create<arith::AddFOp>(loc, pow, result);
|
2022-09-08 10:15:36 +08:00
|
|
|
} else if (isa<AtenFrobeniusNormDimOp>(op)) {
|
|
|
|
Value elem = payloadArgs[0];
|
|
|
|
Value result = payloadArgs[1];
|
2024-04-02 16:33:30 +08:00
|
|
|
|
2023-04-25 23:52:46 +08:00
|
|
|
TypedAttr twoAttr = b.getFloatAttr(resultElementType, 2.0);
|
2022-09-08 10:15:36 +08:00
|
|
|
auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
|
2024-04-02 16:33:30 +08:00
|
|
|
|
|
|
|
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
|
2022-09-08 10:15:36 +08:00
|
|
|
auto pow = b.create<math::PowFOp>(loc, abs, ord);
|
|
|
|
return b.create<arith::AddFOp>(loc, pow, result);
|
2024-04-25 11:15:52 +08:00
|
|
|
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
|
|
|
Value elem = payloadArgs[0];
|
|
|
|
Value result = payloadArgs[1];
|
|
|
|
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
|
|
|
return b.create<arith::AndIOp>(loc, self, result);
|
|
|
|
} else if (isa<AtenAnyOp>(op)) {
|
2024-02-08 04:34:52 +08:00
|
|
|
Value elem = payloadArgs[0];
|
|
|
|
Value result = payloadArgs[1];
|
|
|
|
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
2024-04-25 11:15:52 +08:00
|
|
|
return b.create<arith::OrIOp>(loc, self, result);
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
|
2022-03-11 01:54:13 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ConvertReductionOp : public ConversionPattern {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
private:
|
|
|
|
/// Given a reduction operation that has the `keepdim` attribute and the
|
|
|
|
/// (optional) `dim` attribute, return the source tensor operand and the
|
|
|
|
/// literal values of the attributes or failure otherwise.
|
|
|
|
template <typename T>
|
|
|
|
FailureOr<torch_to_linalg::ReductionOpInfo>
|
|
|
|
computeReductionOpInfoForDimVariantOp(
|
|
|
|
T op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
|
|
|
typename T::Adaptor adaptor(operands);
|
2022-12-08 04:20:41 +08:00
|
|
|
opInfo.tensorOperand = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"`keepdim` must be a constant bool");
|
|
|
|
|
|
|
|
SmallVector<int64_t> dimList;
|
2023-09-06 04:38:51 +08:00
|
|
|
int64_t dim;
|
2024-04-28 05:00:56 +08:00
|
|
|
bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
|
2022-12-08 04:20:41 +08:00
|
|
|
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// Fix negative dimensions, if any, before adding to the list.
|
|
|
|
for (int64_t dim : dimList) {
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
// Drop invalid dimensions
|
|
|
|
if (isValidDim(dim, inputType.getRank()))
|
|
|
|
opInfo.dimSet.insert(dim);
|
|
|
|
}
|
2022-07-28 22:24:24 +08:00
|
|
|
if (dimList.empty())
|
|
|
|
isNoneOrEmptyDimList = true;
|
2023-09-06 04:38:51 +08:00
|
|
|
} else if (matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
if (!isValidDim(dim, inputType.getRank()))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "`dim` argument must be valid, invalid received.");
|
|
|
|
opInfo.dimSet.insert(dim);
|
2022-07-28 22:24:24 +08:00
|
|
|
} else if (!isNoneOrEmptyDimList) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "`dim` argument must be a constant int list or None");
|
|
|
|
}
|
|
|
|
if (isNoneOrEmptyDimList) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// If no dimensions were specified, reduce along all dimensions
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
|
|
|
opInfo.dimSet.insert(i);
|
|
|
|
}
|
|
|
|
|
|
|
|
return opInfo;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Given a reduction operation, return the source tensor operand and the
|
|
|
|
/// literal values of the `keepdim` and `dim` attributes, if any, or failure
|
|
|
|
/// otherwise.
|
|
|
|
FailureOr<torch_to_linalg::ReductionOpInfo>
|
|
|
|
computeReductionOpInfo(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
|
|
|
|
2024-04-25 11:15:52 +08:00
|
|
|
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
|
|
|
|
AtenNormScalarOp>(op)) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
opInfo.tensorOperand = operands[0];
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
|
2024-04-25 11:15:52 +08:00
|
|
|
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
|
|
|
|
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
|
|
|
opInfo.dimSet.insert(i);
|
|
|
|
|
|
|
|
return opInfo;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto sumOp = dyn_cast<AtenSumDimIntListOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter);
|
|
|
|
|
2023-09-06 04:38:51 +08:00
|
|
|
if (auto prodOp = dyn_cast<AtenProdDimIntOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(prodOp, operands, rewriter);
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
|
|
|
|
2022-09-08 10:15:36 +08:00
|
|
|
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
|
|
|
|
2024-02-08 04:34:52 +08:00
|
|
|
if (auto allOp = dyn_cast<AtenAllDimOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generate a linalg.generic operation for pointwise exponentiation of each
|
|
|
|
/// element.
|
|
|
|
Value createElementwiseExp(Location loc, Type elemType, Value exponent,
|
|
|
|
Value inputTensor,
|
|
|
|
const torch_to_linalg::ReductionOpInfo &opInfo,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
bool err = false;
|
|
|
|
auto powBodyBuilder = [&](OpBuilder &builder, Location loc,
|
|
|
|
ValueRange payloadArgs) {
|
|
|
|
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType);
|
|
|
|
auto result = builder.create<math::PowFOp>(loc, elem, exponent);
|
|
|
|
if (result)
|
|
|
|
builder.create<linalg::YieldOp>(loc, Value{result});
|
|
|
|
err = !result;
|
|
|
|
};
|
|
|
|
|
|
|
|
Value powOp = torch_to_linalg::createElementwiseLinalgGeneric(
|
|
|
|
rewriter, loc, {inputTensor}, elemType, powBodyBuilder);
|
|
|
|
return err ? Value{} : powOp;
|
|
|
|
}
|
|
|
|
|
2024-02-27 00:46:56 +08:00
|
|
|
template <typename TOp>
|
|
|
|
FailureOr<Value>
|
|
|
|
createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
|
|
|
|
Value ordOp, Value firstReduction,
|
|
|
|
const torch_to_linalg::ReductionOpInfo &opInfo,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// Cast `ord` to float so that we can readily pass it math.powf.
|
|
|
|
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
|
|
|
|
|
|
|
|
// TODO: Add support for ord = {0, +inf, -inf}.
|
|
|
|
auto epsilon = 1e-5;
|
|
|
|
auto ordLiteral = 0.0;
|
|
|
|
if (matchPattern(ordValue, m_TorchConstantFloat(&ordLiteral)) &&
|
|
|
|
fabs(ordLiteral) < epsilon)
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm");
|
|
|
|
|
|
|
|
if (std::isinf(ordLiteral))
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf");
|
|
|
|
|
|
|
|
// Raise each summed value to the inverse of the order of the norm.
|
2023-04-25 23:52:46 +08:00
|
|
|
TypedAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
auto oneValue = rewriter.create<arith::ConstantOp>(loc, oneAttr);
|
|
|
|
auto inverseOrdValue =
|
|
|
|
rewriter.create<arith::DivFOp>(loc, oneValue, ordValue);
|
|
|
|
|
|
|
|
// Use the results of the first reduction operation from above to generate
|
|
|
|
// a second reduction operation.
|
|
|
|
Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue,
|
|
|
|
firstReduction, opInfo, rewriter);
|
|
|
|
if (!reduceOp)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to create linalg.generic operation for element-wise "
|
|
|
|
"exponentiation");
|
|
|
|
|
|
|
|
return reduceOp;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generate a linalg.generic operation for a reduction.
|
|
|
|
Value createReductionOp(Location loc, Type elemType, Operation *op,
|
|
|
|
ArrayRef<Value> operands,
|
|
|
|
const torch_to_linalg::ReductionOpInfo &opInfo,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
bool err = false;
|
|
|
|
auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc,
|
|
|
|
ValueRange payloadArgs) {
|
|
|
|
Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs,
|
|
|
|
op, operands, elemType);
|
|
|
|
if (result)
|
|
|
|
builder.create<linalg::YieldOp>(loc, result);
|
|
|
|
err = !result;
|
|
|
|
};
|
|
|
|
|
|
|
|
Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType);
|
|
|
|
Value reduceOp = torch_to_linalg::createReductionLinalgGeneric(
|
|
|
|
rewriter, loc, opInfo, initElem, reductionBodyBuilder);
|
|
|
|
return err ? Value{} : reduceOp;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Depending on the operation, check validity of the result's element type.
|
|
|
|
LogicalResult
|
|
|
|
validateReductionElementType(Operation *op, Type elemType,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2024-02-27 00:46:56 +08:00
|
|
|
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
|
|
|
|
isa<AtenNormScalarOp>(op)) &&
|
2024-04-11 21:47:35 +08:00
|
|
|
!isa<mlir::FloatType>(elemType))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only float types are valid for vector norm ops");
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<AtenAllDimOp>(op) && isa<mlir::IntegerType>(elemType) &&
|
2024-02-08 04:34:52 +08:00
|
|
|
elemType.getIntOrFloatBitWidth() == 8)
|
|
|
|
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
|
2024-02-27 00:46:56 +08:00
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// No checks for all other reduction operations
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
public:
|
|
|
|
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
|
|
|
|
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
|
|
|
context) {}
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "invalid operand or result types to use with linalg on tensors");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
FailureOr<torch_to_linalg::ReductionOpInfo> opInfo =
|
|
|
|
computeReductionOpInfo(op, operands, rewriter);
|
|
|
|
if (failed(opInfo))
|
|
|
|
return opInfo;
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Location loc = op->getLoc();
|
2024-05-31 14:45:13 +08:00
|
|
|
auto resultType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
Type elemType = resultType.getElementType();
|
|
|
|
LogicalResult elemTypeCheck =
|
|
|
|
validateReductionElementType(op, elemType, rewriter);
|
|
|
|
if (failed(elemTypeCheck))
|
|
|
|
return elemTypeCheck;
|
|
|
|
|
|
|
|
Value reduceOp =
|
|
|
|
createReductionOp(loc, elemType, op, operands, *opInfo, rewriter);
|
|
|
|
if (!reduceOp)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to create linalg.generic operation for reduction");
|
|
|
|
|
2024-02-27 00:46:56 +08:00
|
|
|
// If this is aten.norm.Scalar op, then we need to generate another
|
|
|
|
// linalg.generic op that references the first linalg.generic op.
|
|
|
|
if (isa<AtenNormScalarOp>(op)) {
|
|
|
|
AtenNormScalarOp::Adaptor adaptor(operands);
|
|
|
|
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
|
|
|
|
loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
|
|
|
|
if (failed(secondReduceOp))
|
|
|
|
return secondReduceOp;
|
|
|
|
reduceOp = *secondReduceOp;
|
|
|
|
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// If this is aten.linalg_vector_norm op, then we need to generate another
|
|
|
|
// linalg.generic op that references the first linalg.generic op.
|
|
|
|
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
|
|
|
|
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
|
2024-02-27 00:46:56 +08:00
|
|
|
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
if (failed(secondReduceOp))
|
|
|
|
return secondReduceOp;
|
|
|
|
reduceOp = *secondReduceOp;
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-09-08 10:15:36 +08:00
|
|
|
// If it is aten.frobenius_norm.dim op, take the square root of reduceOp as
|
|
|
|
// the final result
|
|
|
|
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op)) {
|
|
|
|
auto halfAttr = rewriter.getFloatAttr(elemType, 0.5);
|
|
|
|
auto exp = rewriter.create<arith::ConstantOp>(loc, halfAttr);
|
|
|
|
reduceOp =
|
|
|
|
createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter);
|
|
|
|
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, reduceOp);
|
2022-03-11 01:54:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

  // aten.max.dim / aten.min.dim are lowered by the dedicated
  // ConvertAtenMinMaxDimOp patterns (registered max first, then min).
  target.addIllegalOp<AtenMaxDimOp, AtenMinDimOp>();
  patterns.add<ConvertAtenMinMaxDimOp<AtenMaxDimOp>,
               ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);

  // All remaining reductions and norms go through the generic
  // ConvertReductionOp pattern.
  target.addIllegalOp<AtenSumOp, AtenAnyOp, AtenAllOp, AtenSumDimIntListOp,
                      AtenProdOp, AtenProdDimIntOp, AtenMaxOp, AtenMinOp,
                      AtenAllDimOp, AtenNormScalarOp, AtenLinalgVectorNormOp,
                      AtenFrobeniusNormDimOp>();
  patterns.add<ConvertReductionOp>(typeConverter, context);
}
|