2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "PopulatePatterns.h"
|
|
|
|
#include "Utils.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
#include "llvm/ADT/APSInt.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic
|
|
|
|
// op, producing two output buffers.
|
|
|
|
//
|
|
|
|
// The first output buffer contains the maximum value found. It is initialized
|
|
|
|
// to the minimum representable value of the input element type.
|
|
|
|
//
|
|
|
|
// The second output buffer contains the index of the found maximum value. It is
|
|
|
|
// initialized to 0 and is resulting integer type.
|
|
|
|
//
|
|
|
|
// The indexed_generic op updates both the maximum value and index if the
|
|
|
|
// current value exceeds the running max.
|
|
|
|
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = maxDimOp.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2022-03-11 01:54:13 +08:00
|
|
|
RankedTensorType valResultType =
|
|
|
|
getTypeConverter()
|
|
|
|
->convertType(maxDimOp.getResult(0).getType())
|
|
|
|
.cast<RankedTensorType>();
|
|
|
|
RankedTensorType idxResultType =
|
|
|
|
getTypeConverter()
|
|
|
|
->convertType(maxDimOp.getResult(1).getType())
|
|
|
|
.cast<RankedTensorType>();
|
|
|
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
|
|
|
Type idxElementType = idxResultType.getElementType();
|
|
|
|
if (!idxElementType.isa<IntegerType>())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
maxDimOp,
|
|
|
|
"aten.max_dim to linalg.* requires integer-like result type");
|
|
|
|
|
|
|
|
bool keepDim = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(maxDimOp.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
maxDimOp, "aten.max_dim requires boolean value for keepdim");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(maxDimOp.getDim(), m_TorchConstantInt(&dim)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
if (!isValidDim(dim, inputType.getRank()))
|
|
|
|
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");
|
|
|
|
|
|
|
|
Type inElementType = inputType.getElementType();
|
|
|
|
if (!inElementType.isa<mlir::FloatType>()) {
|
2023-02-06 19:52:04 +08:00
|
|
|
if (inElementType.isa<mlir::IntegerType>()) {
|
|
|
|
auto integerTy = maxDimOp.getSelf()
|
|
|
|
.getType()
|
|
|
|
.cast<BaseTensorType>()
|
|
|
|
.getDtype()
|
|
|
|
.dyn_cast<mlir::IntegerType>();
|
|
|
|
if (integerTy.isUnsigned())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
maxDimOp, "aten.max_dim to linalg.* requires input element type "
|
|
|
|
"to be signed in case of integer");
|
|
|
|
} else {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
|
|
|
|
"input element type");
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Constant op to account for the reduction along dim.
|
|
|
|
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
|
|
|
|
SmallVector<Value> resultShape;
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
|
|
|
if (dim != i) {
|
|
|
|
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
|
|
|
|
resultShape.push_back(currentDimSize);
|
|
|
|
} else if (keepDim)
|
|
|
|
resultShape.push_back(c1);
|
|
|
|
}
|
|
|
|
// First fill the output buffer for the index.
|
|
|
|
Value filledTensorIdx =
|
|
|
|
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
|
|
|
|
|
|
|
|
// Second fill the output buffer for the running max.
|
2022-10-18 12:22:53 +08:00
|
|
|
Value initTensorMax = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(resultShape), inElementType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2023-02-06 19:52:04 +08:00
|
|
|
Value fillValueMax;
|
|
|
|
if (inElementType.isa<mlir::FloatType>()) {
|
|
|
|
fillValueMax = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc,
|
|
|
|
rewriter.getFloatAttr(
|
|
|
|
inElementType,
|
2023-05-09 00:17:49 +08:00
|
|
|
APFloat::getInf(
|
2023-02-06 19:52:04 +08:00
|
|
|
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
2023-05-09 00:17:49 +08:00
|
|
|
/*Negative=*/true)));
|
2023-02-06 19:52:04 +08:00
|
|
|
} else {
|
|
|
|
fillValueMax = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(
|
|
|
|
inElementType,
|
|
|
|
APSInt::getSignedMinValue(
|
|
|
|
inElementType.cast<mlir::IntegerType>().getWidth())));
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value filledTensorMax =
|
|
|
|
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
|
|
|
|
.result();
|
|
|
|
|
|
|
|
// Create the affine expressions that will be used to
|
|
|
|
// iterate over the input and output tensors.
|
|
|
|
// Here we also set the type of iterator: parallel or reduction.
|
|
|
|
SmallVector<AffineExpr> exprs;
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes;
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<AffineExpr> resultExprs;
|
2022-11-29 20:33:31 +08:00
|
|
|
for (auto size :
|
|
|
|
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
|
2022-03-11 01:54:13 +08:00
|
|
|
exprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
|
|
|
|
|
|
|
if (unsigned(dim) == size.index()) {
|
2022-11-17 06:40:36 +08:00
|
|
|
iteratorTypes.push_back(utils::IteratorType::reduction);
|
2022-03-11 01:54:13 +08:00
|
|
|
// If `keepDim`, create affine map to the first element
|
|
|
|
// in the current dimension.
|
|
|
|
if (keepDim)
|
|
|
|
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
|
|
|
|
} else {
|
2022-11-17 06:40:36 +08:00
|
|
|
iteratorTypes.push_back(utils::IteratorType::parallel);
|
2022-03-11 01:54:13 +08:00
|
|
|
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
|
|
|
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
|
|
|
loc,
|
|
|
|
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
|
|
|
|
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
|
|
|
|
iteratorTypes,
|
|
|
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
|
|
|
ValueRange blockArgs) {
|
|
|
|
Value newValue = blockArgs[0];
|
|
|
|
Value oldValue = blockArgs[1];
|
|
|
|
Value oldIndex = blockArgs[2];
|
|
|
|
|
|
|
|
Value newIndex = rewriter.create<arith::IndexCastOp>(
|
|
|
|
nestedLoc, oldIndex.getType(),
|
|
|
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
|
|
|
|
2023-02-06 19:52:04 +08:00
|
|
|
Value resultMax, predicate;
|
|
|
|
if (inElementType.isa<mlir::FloatType>()) {
|
|
|
|
resultMax =
|
|
|
|
rewriter.create<arith::MaxFOp>(nestedLoc, newValue, oldValue);
|
|
|
|
predicate = rewriter.create<arith::CmpFOp>(
|
|
|
|
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
|
|
|
} else {
|
|
|
|
resultMax =
|
|
|
|
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
|
|
|
|
predicate = rewriter.create<arith::CmpIOp>(
|
|
|
|
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
auto resultIndex = rewriter.create<arith::SelectOp>(
|
|
|
|
nestedLoc, predicate, newIndex, oldIndex);
|
|
|
|
nestedBuilder.create<linalg::YieldOp>(
|
|
|
|
nestedLoc, ValueRange({resultMax, resultIndex}));
|
|
|
|
});
|
|
|
|
|
|
|
|
// This cast is required to fix the shape in the case of keepDim=True
|
|
|
|
Value maxValuesCast = rewriter.create<tensor::CastOp>(
|
|
|
|
loc, valResultType, linalgOp.getResult(0));
|
|
|
|
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
|
|
|
|
linalgOp.getResult(1));
|
|
|
|
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
/// Returns the initial accumulator value for the first linalg.generic
/// reduction emitted for `op`, materialized as an `arith.constant` of
/// `elementType`.
///
/// - aten.sum / aten.sum.dim_IntList / aten.linalg_vector_norm /
///   aten.frobenius_norm.dim: zero.
/// - aten.max: -inf for float element types. For integer element types, the
///   minimum of the torch dtype: 0 for unsigned dtypes, the signed minimum
///   otherwise.
///
/// Emits an error on `op` and returns a null Value for unsupported ops or
/// element types.
static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
                                          Operation *op, Type elementType) {
  if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
    return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

  if (auto maxOp = dyn_cast<AtenMaxOp>(op)) {
    if (auto floatType = elementType.dyn_cast<mlir::FloatType>())
      return b.create<arith::ConstantOp>(
          loc, b.getFloatAttr(elementType,
                              APFloat::getInf(floatType.getFloatSemantics(),
                                              /*Negative=*/true)));
    if (auto intType = elementType.dyn_cast<mlir::IntegerType>()) {
      // The converted element type is signless, so signedness has to come from
      // the torch dtype of the original operand. This must agree with the
      // maxui/maxsi choice in createLinalgPayloadForReduceOp: the signed
      // minimum is not the identity of an unsigned max (e.g. for uint8 it is
      // 128), while 0 is.
      auto torchIntType = maxOp.getSelf()
                              .getType()
                              .cast<BaseTensorType>()
                              .getDtype()
                              .dyn_cast_or_null<mlir::IntegerType>();
      unsigned width = intType.getWidth();
      APInt minValue = (torchIntType && torchIntType.isUnsigned())
                           ? APInt::getMinValue(width)
                           : APInt::getSignedMinValue(width);
      return b.create<arith::ConstantOp>(
          loc, b.getIntegerAttr(elementType, minValue));
    }
  }

  if (isa<AtenLinalgVectorNormOp, AtenFrobeniusNormDimOp>(op))
    return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

  op->emitError("unimplemented lowering in createInitElementForReduceOp");
  return nullptr;
}
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
/// Builds the scalar body of the first linalg.generic reduction for `op`.
///
/// `payloadArgs[0]` is the current input element and `payloadArgs[1]` is the
/// running accumulator; the returned Value is the updated accumulator.
/// `operands` are the type-converted operands of `op`, and `resultElementType`
/// is the element type of the accumulator. Emits an error on `op` and returns
/// a null Value for unsupported ops or element types.
static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
                                            ValueRange payloadArgs,
                                            Operation *op,
                                            ArrayRef<Value> operands,
                                            Type resultElementType) {
  if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
    Value self =
        convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
    Value result = payloadArgs[1];
    if (resultElementType.isa<mlir::FloatType>())
      return b.create<arith::AddFOp>(loc, self, result);
    if (resultElementType.isa<mlir::IntegerType>())
      return b.create<arith::AddIOp>(loc, self, result);
  } else if (auto maxOp = dyn_cast<AtenMaxOp>(op)) {
    Value self =
        convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
    Value result = payloadArgs[1];
    if (resultElementType.isa<mlir::FloatType>())
      return b.create<arith::MaxFOp>(loc, self, result);
    if (resultElementType.isa<mlir::IntegerType>()) {
      // The converted element type is signless; signedness is only available
      // on the torch dtype of the original operand. Fall through to the error
      // below, rather than dereferencing null, if that dtype is not an
      // IntegerType.
      auto intType = maxOp.getSelf()
                         .getType()
                         .cast<BaseTensorType>()
                         .getDtype()
                         .dyn_cast_or_null<mlir::IntegerType>();
      if (intType && intType.isUnsigned())
        return b.create<arith::MaxUIOp>(loc, self, result);
      if (intType && intType.isSigned())
        return b.create<arith::MaxSIOp>(loc, self, result);
    }
  } else if (isa<AtenLinalgVectorNormOp>(op)) {
    // This creates payload for only the first of the two linalg.generic ops:
    // it accumulates |x|^ord.
    // TODO: Short-circuit operations if `ord` is zero or one.
    Value elem = payloadArgs[0];
    Value result = payloadArgs[1];
    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
    auto abs = b.create<math::AbsFOp>(loc, self);
    AtenLinalgVectorNormOp::Adaptor adaptor(operands);
    Value ord =
        convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
    auto pow = b.create<math::PowFOp>(loc, abs, ord);
    return b.create<arith::AddFOp>(loc, pow, result);
  } else if (isa<AtenFrobeniusNormDimOp>(op)) {
    // Accumulates |x|^2. Squaring does not need the abs, and a single multiply
    // is exactly rounded and cheaper than math.powf with a constant exponent
    // of 2.
    Value self =
        convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
    Value result = payloadArgs[1];
    Value square = b.create<arith::MulFOp>(loc, self, self);
    return b.create<arith::AddFOp>(loc, square, result);
  }
  op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
  return nullptr;
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ConvertReductionOp : public ConversionPattern {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
private:
|
|
|
|
/// Given a reduction operation that has the `keepdim` attribute and the
|
|
|
|
/// (optional) `dim` attribute, return the source tensor operand and the
|
|
|
|
/// literal values of the attributes or failure otherwise.
|
|
|
|
template <typename T>
|
|
|
|
FailureOr<torch_to_linalg::ReductionOpInfo>
|
|
|
|
computeReductionOpInfoForDimVariantOp(
|
|
|
|
T op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
|
|
|
typename T::Adaptor adaptor(operands);
|
2022-12-08 04:20:41 +08:00
|
|
|
opInfo.tensorOperand = adaptor.getSelf();
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"`keepdim` must be a constant bool");
|
|
|
|
|
|
|
|
SmallVector<int64_t> dimList;
|
2022-07-28 22:24:24 +08:00
|
|
|
bool isNoneOrEmptyDimList =
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getDim().getType().template isa<Torch::NoneType>();
|
|
|
|
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// Fix negative dimensions, if any, before adding to the list.
|
|
|
|
for (int64_t dim : dimList) {
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
// Drop invalid dimensions
|
|
|
|
if (isValidDim(dim, inputType.getRank()))
|
|
|
|
opInfo.dimSet.insert(dim);
|
|
|
|
}
|
2022-07-28 22:24:24 +08:00
|
|
|
if (dimList.empty())
|
|
|
|
isNoneOrEmptyDimList = true;
|
|
|
|
} else if (!isNoneOrEmptyDimList) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "`dim` argument must be a constant int list or None");
|
|
|
|
}
|
|
|
|
if (isNoneOrEmptyDimList) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// If no dimensions were specified, reduce along all dimensions
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
|
|
|
opInfo.dimSet.insert(i);
|
|
|
|
}
|
|
|
|
|
|
|
|
return opInfo;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Given a reduction operation, return the source tensor operand and the
|
|
|
|
/// literal values of the `keepdim` and `dim` attributes, if any, or failure
|
|
|
|
/// otherwise.
|
|
|
|
FailureOr<torch_to_linalg::ReductionOpInfo>
|
|
|
|
computeReductionOpInfo(Operation *op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
|
|
|
|
|
|
|
if (isa<AtenMaxOp, AtenSumOp>(op)) {
|
|
|
|
opInfo.tensorOperand = operands[0];
|
|
|
|
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
|
|
|
|
|
|
|
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
|
|
|
|
// input tensor.
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
|
|
|
opInfo.dimSet.insert(i);
|
|
|
|
|
|
|
|
return opInfo;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto sumOp = dyn_cast<AtenSumDimIntListOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter);
|
|
|
|
|
|
|
|
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
|
|
|
|
2022-09-08 10:15:36 +08:00
|
|
|
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
|
|
|
|
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generate a linalg.generic operation for pointwise exponentiation of each
|
|
|
|
/// element.
|
|
|
|
Value createElementwiseExp(Location loc, Type elemType, Value exponent,
|
|
|
|
Value inputTensor,
|
|
|
|
const torch_to_linalg::ReductionOpInfo &opInfo,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
bool err = false;
|
|
|
|
auto powBodyBuilder = [&](OpBuilder &builder, Location loc,
|
|
|
|
ValueRange payloadArgs) {
|
|
|
|
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType);
|
|
|
|
auto result = builder.create<math::PowFOp>(loc, elem, exponent);
|
|
|
|
if (result)
|
|
|
|
builder.create<linalg::YieldOp>(loc, Value{result});
|
|
|
|
err = !result;
|
|
|
|
};
|
|
|
|
|
|
|
|
Value powOp = torch_to_linalg::createElementwiseLinalgGeneric(
|
|
|
|
rewriter, loc, {inputTensor}, elemType, powBodyBuilder);
|
|
|
|
return err ? Value{} : powOp;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Emits the second stage of the `aten.linalg_vector_norm` lowering.
///
/// Given `firstReduction` (the sum of |x|^ord over the reduced dimensions),
/// raises each element to 1/ord. When `ord` is a compile-time constant equal
/// to 0 (L0 norm) or +/-inf, returns failure because those orders are not
/// lowered yet. A non-constant `ord` goes through the generic formula.
/// `ordOp` is the type-converted `ord` operand; `opInfo` is forwarded
/// unchanged.
FailureOr<Value> createSecondReductionForVectorNormOp(
    Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
    Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
    ConversionPatternRewriter &rewriter) const {
  // TODO: Add support for ord = {0, +inf, -inf}.
  // The torch constant matchers only recognize torch.constant.* ops. `ordOp`
  // (and anything derived from it via convertScalarToDtype) is a builtin
  // value, so it would never match. Inspect the original torch operand
  // instead, accepting both float and int constants because `ord` is a
  // Scalar.
  double ordLiteral = 0.0;
  int64_t ordIntLiteral = 0;
  bool ordIsConstant = true;
  if (matchPattern(op.getOrd(), m_TorchConstantFloat(&ordLiteral))) {
    // `ordLiteral` now holds the float constant.
  } else if (matchPattern(op.getOrd(), m_TorchConstantInt(&ordIntLiteral))) {
    ordLiteral = static_cast<double>(ordIntLiteral);
  } else {
    ordIsConstant = false;
  }

  constexpr double epsilon = 1e-5;
  if (ordIsConstant) {
    if (std::fabs(ordLiteral) < epsilon)
      return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm");
    if (std::isinf(ordLiteral))
      return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf");
  }

  // Cast `ord` to float so that we can readily pass it to math.powf.
  Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);

  // Raise each summed value to the inverse of the order of the norm.
  TypedAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0);
  auto oneValue = rewriter.create<arith::ConstantOp>(loc, oneAttr);
  auto inverseOrdValue =
      rewriter.create<arith::DivFOp>(loc, oneValue, ordValue);

  // Use the results of the first reduction operation from above to generate
  // a second reduction operation.
  Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue,
                                        firstReduction, opInfo, rewriter);
  if (!reduceOp)
    return rewriter.notifyMatchFailure(
        op, "failed to create linalg.generic operation for element-wise "
            "exponentiation");

  return reduceOp;
}
|
|
|
|
|
|
|
|
/// Generate a linalg.generic operation for a reduction.
|
|
|
|
Value createReductionOp(Location loc, Type elemType, Operation *op,
|
|
|
|
ArrayRef<Value> operands,
|
|
|
|
const torch_to_linalg::ReductionOpInfo &opInfo,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
bool err = false;
|
|
|
|
auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc,
|
|
|
|
ValueRange payloadArgs) {
|
|
|
|
Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs,
|
|
|
|
op, operands, elemType);
|
|
|
|
if (result)
|
|
|
|
builder.create<linalg::YieldOp>(loc, result);
|
|
|
|
err = !result;
|
|
|
|
};
|
|
|
|
|
|
|
|
Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType);
|
|
|
|
Value reduceOp = torch_to_linalg::createReductionLinalgGeneric(
|
|
|
|
rewriter, loc, opInfo, initElem, reductionBodyBuilder);
|
|
|
|
return err ? Value{} : reduceOp;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Depending on the operation, check validity of the result's element type.
|
|
|
|
LogicalResult
|
|
|
|
validateReductionElementType(Operation *op, Type elemType,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-09-08 10:15:36 +08:00
|
|
|
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
|
|
|
|
!elemType.isa<mlir::FloatType>())
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only float types are valid for vector norm ops");
|
|
|
|
// No checks for all other reduction operations
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
public:
|
|
|
|
/// The pattern is registered with MatchAnyOpTypeTag, so the driver offers it
/// every op. matchAndRewrite is responsible for rejecting ops it does not
/// handle.
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
    : ConversionPattern(typeConverter, MatchAnyOpTypeTag(),
                        /*benefit=*/PatternBenefit(1), context) {}
|
|
|
|
/// Lower a supported torch reduction op to linalg.generic op(s), then cast the
/// result to the converted result type.
///
/// Norm ops get extra post-processing after the reduction:
///   * aten.linalg_vector_norm: an element-wise pow(x, 1/ord) stage.
///   * aten.frobenius_norm.dim: an element-wise pow(x, 0.5) stage.
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const override {
  if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
    return rewriter.notifyMatchFailure(
        op, "invalid operand or result types to use with linalg on tensors");

  FailureOr<torch_to_linalg::ReductionOpInfo> opInfo =
      computeReductionOpInfo(op, operands, rewriter);
  if (failed(opInfo))
    return opInfo;

  Location loc = op->getLoc();
  auto resultType = getTypeConverter()
                        ->convertType(op->getResult(0).getType())
                        .cast<RankedTensorType>();
  Type elemType = resultType.getElementType();
  LogicalResult elemTypeCheck =
      validateReductionElementType(op, elemType, rewriter);
  if (failed(elemTypeCheck))
    return elemTypeCheck;

  Value reduceOp =
      createReductionOp(loc, elemType, op, operands, *opInfo, rewriter);
  if (!reduceOp)
    return rewriter.notifyMatchFailure(
        op, "failed to create linalg.generic operation for reduction");

  // If this is aten.linalg_vector_norm op, then we need to generate another
  // linalg.generic op that references the first linalg.generic op.
  if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
    AtenLinalgVectorNormOp::Adaptor adaptor(operands);
    FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
        loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
    if (failed(secondReduceOp))
      return secondReduceOp;
    reduceOp = *secondReduceOp;
  }

  // If it is aten.frobenius_norm.dim op, take the square root of reduceOp as
  // the final result.
  if (isa<AtenFrobeniusNormDimOp>(op)) {
    auto halfAttr = rewriter.getFloatAttr(elemType, 0.5);
    auto exp = rewriter.create<arith::ConstantOp>(loc, halfAttr);
    reduceOp =
        createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter);
    // Mirror the vector-norm path: never feed a null value into the cast.
    if (!reduceOp)
      return rewriter.notifyMatchFailure(
          op, "failed to create linalg.generic operation for element-wise "
              "exponentiation");
  }

  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, reduceOp);
  return success();
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Register the reduction lowering patterns and mark the ops they handle as
/// illegal, so that dialect conversion must rewrite them.
void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();

  // aten.max.dim has its own dedicated pattern.
  target.addIllegalOp<AtenMaxDimOp>();
  patterns.add<ConvertAtenMaxDimOp>(typeConverter, ctx);

  // The remaining reductions are all handled by the generic ConvertReductionOp.
  target.addIllegalOp<AtenSumOp, AtenSumDimIntListOp, AtenMaxOp,
                      AtenLinalgVectorNormOp, AtenFrobeniusNormDimOp>();
  patterns.add<ConvertReductionOp>(typeConverter, ctx);
}
|