2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "PopulatePatterns.h"
|
|
|
|
#include "Utils.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
2022-04-01 16:23:29 +08:00
|
|
|
// Checks the validity of pooling parameters and stores them in the respective
|
|
|
|
// vector.
|
|
|
|
template <typename OpTy>
|
|
|
|
static LogicalResult
|
|
|
|
checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
2022-05-13 20:06:24 +08:00
|
|
|
TypeConverter *typeConverter, bool &ceilMode,
|
|
|
|
SmallVectorImpl<Value> &kernelSizeIntValues,
|
2022-04-01 16:23:29 +08:00
|
|
|
SmallVectorImpl<int64_t> &strideInts,
|
2022-03-30 23:14:23 +08:00
|
|
|
SmallVectorImpl<int64_t> &paddingInts) {
|
2022-04-01 16:23:29 +08:00
|
|
|
// Pattern match against the op's original operands, because otherwise we
|
|
|
|
// will get the lowered version of the operands which is harder to pattern
|
|
|
|
// match.
|
2022-05-13 20:06:24 +08:00
|
|
|
SmallVector<Value, 2> kernelSizeTorchInt;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) {
|
2022-05-13 20:06:24 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"unimplemented: the kernel size is "
|
|
|
|
"not constructed from ListConstruct");
|
|
|
|
}
|
|
|
|
kernelSizeIntValues = getTypeConvertedValues(
|
|
|
|
rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt);
|
2023-04-27 23:31:36 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "only support constant int strides");
|
2023-04-27 23:31:36 +08:00
|
|
|
// If `stride` is not specified by the user, it is assigned the value of empty
|
|
|
|
// list during import. For such a case, the stride value is the kernel size.
|
|
|
|
// See:
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
|
|
|
if (strideInts.empty()) {
|
|
|
|
if (!matchPattern(op.getKernelSize(),
|
|
|
|
m_TorchListOfConstantInts(strideInts))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "if stride is the empty list, kernel_size must be a list of "
|
|
|
|
"constant ints");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int paddings");
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant bool ceil_mode");
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-03-30 23:14:23 +08:00
|
|
|
// Creates a pooling operation based on the type specified by `OpTy` and
|
|
|
|
// arguments passed.
|
|
|
|
template <typename OpTy>
|
|
|
|
static LogicalResult createPoolingOp(
|
|
|
|
Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
2022-10-27 08:30:45 +08:00
|
|
|
bool supportNonFPInput, bool ceilMode,
|
2022-05-13 20:06:24 +08:00
|
|
|
SmallVectorImpl<Value> &kernelSizeIntValues,
|
2022-03-30 23:14:23 +08:00
|
|
|
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
|
2022-05-03 21:22:42 +08:00
|
|
|
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
|
2022-03-30 23:14:23 +08:00
|
|
|
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
2022-04-01 16:23:29 +08:00
|
|
|
Location loc = op->getLoc();
|
|
|
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
2022-10-27 08:30:45 +08:00
|
|
|
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
|
2022-04-01 16:23:29 +08:00
|
|
|
return op->emitError("unimplemented: non-floating point type");
|
|
|
|
|
2022-05-03 21:22:42 +08:00
|
|
|
SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
|
|
|
|
lowPaddingIncludingNC.append(paddingInts);
|
|
|
|
SmallVector<int64_t, 4> highPaddingIncludingNC = lowPaddingIncludingNC;
|
|
|
|
if (ceilMode) {
|
|
|
|
highPaddingIncludingNC[2] += strideInts[0];
|
|
|
|
highPaddingIncludingNC[3] += strideInts[1];
|
|
|
|
}
|
2023-04-25 23:52:46 +08:00
|
|
|
Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
|
2022-03-30 23:14:23 +08:00
|
|
|
paddedInput = torch_to_linalg::getPaddedTensor(
|
2022-05-03 21:22:42 +08:00
|
|
|
op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
|
|
|
|
initValue);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
Value N = getDimOp(rewriter, loc, self, 0);
|
|
|
|
Value C = getDimOp(rewriter, loc, self, 1);
|
|
|
|
Value H = getDimOp(rewriter, loc, self, 2);
|
|
|
|
Value W = getDimOp(rewriter, loc, self, 3);
|
|
|
|
|
|
|
|
SmallVector<Value> paddingIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, paddingInts);
|
|
|
|
SmallVector<Value> dilationIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, dilationInts);
|
|
|
|
SmallVector<Value> strideIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, strideInts);
|
|
|
|
|
|
|
|
Value hOut = torch_to_linalg::getOutputDimForConvOps(
|
|
|
|
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
|
2022-05-03 21:22:42 +08:00
|
|
|
kernelSizeIntValues[0], strideIntValues[0], ceilMode);
|
2022-04-01 16:23:29 +08:00
|
|
|
Value wOut = torch_to_linalg::getOutputDimForConvOps(
|
|
|
|
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
|
2022-05-03 21:22:42 +08:00
|
|
|
kernelSizeIntValues[1], strideIntValues[1], ceilMode);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
// Create output tensor initialized with smallest floating point value.
|
2022-03-30 23:14:23 +08:00
|
|
|
outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut});
|
2022-04-01 16:23:29 +08:00
|
|
|
Value outTensorInitialized =
|
2022-03-30 23:14:23 +08:00
|
|
|
createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
|
|
|
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
|
2022-10-18 12:22:53 +08:00
|
|
|
auto shape = castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
|
|
|
|
Value windowTensor = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(shape), elementType);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
result = rewriter
|
2022-03-30 23:14:23 +08:00
|
|
|
.create<OpTy>(loc, outTensorInitialized.getType(),
|
|
|
|
ValueRange{paddedInput, windowTensor},
|
|
|
|
outTensorInitialized, stridesAttr, dilationAttr)
|
2022-04-01 16:23:29 +08:00
|
|
|
.getResult(0);
|
2022-03-30 23:14:23 +08:00
|
|
|
|
2022-04-01 16:23:29 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
namespace {
|
|
|
|
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
2022-04-01 16:23:29 +08:00
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
TypeConverter *typeConverter = getTypeConverter();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-04-01 16:23:29 +08:00
|
|
|
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
|
|
|
|
// TODO: Add support for 3D inputs.
|
|
|
|
if (selfRank == 3)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only support 4D input");
|
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
bool ceilMode;
|
|
|
|
SmallVector<Value, 2> kernelSizeIntValues;
|
|
|
|
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
2022-03-30 23:14:23 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int dilations");
|
2022-04-01 16:23:29 +08:00
|
|
|
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
|
2022-05-13 20:06:24 +08:00
|
|
|
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
|
|
|
|
strideInts, paddingInts)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
|
|
|
|
2022-03-30 23:14:23 +08:00
|
|
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
2023-04-25 23:52:46 +08:00
|
|
|
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
2022-03-30 23:14:23 +08:00
|
|
|
elementType,
|
2023-05-09 00:17:49 +08:00
|
|
|
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
|
|
|
/*Negative=*/true));
|
2022-03-30 23:14:23 +08:00
|
|
|
SmallVector<Value, 4> outTensorShape;
|
|
|
|
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
|
|
|
Value maxPool2d, paddedInput;
|
|
|
|
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
2022-10-27 08:30:45 +08:00
|
|
|
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
|
2022-05-13 20:06:24 +08:00
|
|
|
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
2022-05-03 21:22:42 +08:00
|
|
|
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// Returns the result of maxpool2d over the input tensor. And the corresponding
|
|
|
|
// indices of the input tensor for the values of the result tensor.
|
|
|
|
//
|
|
|
|
// The result of the maxpool2d operation is calculated using the helper function
|
|
|
|
// written above. For finding the indices, we follow the below method:
|
|
|
|
//
|
|
|
|
// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will
|
|
|
|
// also be a 4-d tensor. Then:
|
|
|
|
// for i in range(N):
|
|
|
|
// for j in range(C):
|
|
|
|
// for m in range(Hout):
|
|
|
|
// for n in range(Wout):
|
|
|
|
// for p in range(kH):
|
|
|
|
// for r in range(kW):
|
|
|
|
// indexH = m * stride[0] + p * dilation[0]
|
|
|
|
// indexW = n * stride[0] + r * dilation[0]
|
|
|
|
// if paddedInput[i, j, indexH, indexW] ==
|
|
|
|
// maxPool2d[i, j, m, n]:
|
|
|
|
// indices[i, j, m, n] = (indexH - padding[0]) * W +
|
|
|
|
// (indexW - padding[1])
|
|
|
|
//
|
|
|
|
class ConvertAtenMaxPool2dWithIndicesOp
|
|
|
|
: public OpConversionPattern<AtenMaxPool2dWithIndicesOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
2022-03-11 01:54:13 +08:00
|
|
|
Location loc = op->getLoc();
|
2022-05-13 20:06:24 +08:00
|
|
|
TypeConverter *typeConverter = getTypeConverter();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-04-01 16:23:29 +08:00
|
|
|
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
|
|
|
Type elementType = selfType.getElementType();
|
|
|
|
RankedTensorType indicesRankedTensorType =
|
|
|
|
getTypeConverter()
|
|
|
|
->convertType(op->getResult(1).getType())
|
|
|
|
.cast<RankedTensorType>();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-04-01 16:23:29 +08:00
|
|
|
// TODO: Add support for 3D inputs.
|
|
|
|
if (selfType.getRank() == 3)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only support 4D input");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
bool ceilMode;
|
|
|
|
SmallVector<Value, 2> kernelSizeIntValues;
|
|
|
|
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
2022-03-30 23:14:23 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int dilations");
|
2022-04-01 16:23:29 +08:00
|
|
|
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
|
2022-05-13 20:06:24 +08:00
|
|
|
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
|
|
|
|
strideInts, paddingInts)))
|
2022-04-01 16:23:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
|
|
|
|
2022-03-30 23:14:23 +08:00
|
|
|
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
2022-04-01 16:23:29 +08:00
|
|
|
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
2022-03-11 01:54:13 +08:00
|
|
|
elementType,
|
2023-05-09 00:17:49 +08:00
|
|
|
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
|
|
|
/*Negative=*/true));
|
2022-03-30 23:14:23 +08:00
|
|
|
Value maxPool2d, paddedInput;
|
|
|
|
SmallVector<Value, 4> outTensorShape;
|
|
|
|
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
2022-10-27 08:30:45 +08:00
|
|
|
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
|
2022-05-13 20:06:24 +08:00
|
|
|
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
2022-05-03 21:22:42 +08:00
|
|
|
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
|
2022-03-30 23:14:23 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
Value cstMinusOne =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
Value indicesTensor =
|
2022-03-30 23:14:23 +08:00
|
|
|
createInitTensor(rewriter, loc, outTensorShape,
|
2022-04-01 16:23:29 +08:00
|
|
|
indicesRankedTensorType.getElementType(), cstMinusOne);
|
|
|
|
|
|
|
|
SmallVector<Value> kernelSize =
|
2022-05-13 20:06:24 +08:00
|
|
|
castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
|
2022-04-01 16:23:29 +08:00
|
|
|
SmallVector<Value> padding =
|
|
|
|
getAsConstantIndexValues(rewriter, loc, paddingInts);
|
|
|
|
SmallVector<Value> dilation =
|
|
|
|
getAsConstantIndexValues(rewriter, loc, dilationInts);
|
|
|
|
SmallVector<Value> stride =
|
|
|
|
getAsConstantIndexValues(rewriter, loc, strideInts);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-10-18 12:22:53 +08:00
|
|
|
Value windowTensor = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(kernelSize),
|
|
|
|
indicesRankedTensorType.getElementType());
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
|
|
|
|
for (unsigned i = 0; i < 4; i++) {
|
|
|
|
inputExprs.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
outputExprs.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
}
|
|
|
|
kernelExprs.push_back(rewriter.getAffineDimExpr(4));
|
|
|
|
kernelExprs.push_back(rewriter.getAffineDimExpr(5));
|
|
|
|
|
|
|
|
// Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH,
|
|
|
|
// and kW, respectively, as described in the algorithm above.
|
|
|
|
SmallVector<AffineMap> indexingMaps =
|
|
|
|
AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs});
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
4, utils::IteratorType::parallel);
|
|
|
|
iteratorTypes.push_back(utils::IteratorType::reduction);
|
|
|
|
iteratorTypes.push_back(utils::IteratorType::reduction);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
// Input format is : [N, C, H, W]
|
|
|
|
Value inputShapeW = getDimOp(rewriter, loc, self, 3);
|
|
|
|
|
|
|
|
Value indicesResult =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, /*resultTensorTypes=*/indicesTensor.getType(),
|
|
|
|
/*inputs=*/ValueRange({maxPool2d, windowTensor}),
|
|
|
|
/*outputs=*/indicesTensor,
|
|
|
|
/*indexingMaps=*/indexingMaps,
|
|
|
|
/*iteratorTypes=*/iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
Value maxVal = args[0], res = args[2];
|
|
|
|
|
|
|
|
Value i = b.create<linalg::IndexOp>(loc, 0);
|
|
|
|
Value j = b.create<linalg::IndexOp>(loc, 1);
|
|
|
|
Value m = b.create<linalg::IndexOp>(loc, 2);
|
|
|
|
Value n = b.create<linalg::IndexOp>(loc, 3);
|
|
|
|
Value p = b.create<linalg::IndexOp>(loc, 4);
|
|
|
|
Value r = b.create<linalg::IndexOp>(loc, 5);
|
|
|
|
|
|
|
|
Value mTimesStride =
|
|
|
|
b.create<arith::MulIOp>(loc, m, stride[0]);
|
|
|
|
Value pTimesDilation =
|
|
|
|
b.create<arith::MulIOp>(loc, p, dilation[0]);
|
|
|
|
Value indexH = b.create<arith::AddIOp>(loc, mTimesStride,
|
|
|
|
pTimesDilation);
|
|
|
|
Value nTimesStride =
|
|
|
|
b.create<arith::MulIOp>(loc, n, stride[1]);
|
|
|
|
Value rTimesDilation =
|
|
|
|
b.create<arith::MulIOp>(loc, r, dilation[1]);
|
|
|
|
Value indexW = b.create<arith::AddIOp>(loc, nTimesStride,
|
|
|
|
rTimesDilation);
|
|
|
|
Value input = b.create<tensor::ExtractOp>(
|
|
|
|
loc, paddedInput, ValueRange{i, j, indexH, indexW});
|
|
|
|
Value pred = b.create<arith::CmpFOp>(
|
|
|
|
loc, arith::CmpFPredicate::OEQ, input, maxVal);
|
|
|
|
|
|
|
|
Value indexHMinusPadding =
|
|
|
|
b.create<arith::SubIOp>(loc, indexH, padding[0]);
|
|
|
|
Value indexWMinusPadding =
|
|
|
|
b.create<arith::SubIOp>(loc, indexW, padding[1]);
|
|
|
|
Value outIndex = b.create<arith::MulIOp>(
|
|
|
|
loc, indexHMinusPadding, inputShapeW);
|
|
|
|
outIndex = b.create<arith::AddIOp>(loc, outIndex,
|
|
|
|
indexWMinusPadding);
|
|
|
|
Value result = b.create<arith::SelectOp>(
|
2022-04-22 01:10:04 +08:00
|
|
|
loc, pred, castIndexToInt64(b, loc, outIndex), res);
|
2022-04-01 16:23:29 +08:00
|
|
|
|
|
|
|
Value predInvalidIndex = b.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::eq, res, cstMinusOne);
|
|
|
|
Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
|
|
|
|
result, res);
|
|
|
|
|
|
|
|
b.create<linalg::YieldOp>(loc, out);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
Type maxPool2dResultType =
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType());
|
|
|
|
Type indicesResultType =
|
|
|
|
getTypeConverter()->convertType(op->getResult(1).getType());
|
|
|
|
Value outMaxpool2d =
|
|
|
|
rewriter.create<tensor::CastOp>(loc, maxPool2dResultType, maxPool2d);
|
|
|
|
Value outIndices =
|
|
|
|
rewriter.create<tensor::CastOp>(loc, indicesResultType, indicesResult);
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {outMaxpool2d, outIndices});
|
2022-03-11 01:54:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-30 23:14:23 +08:00
|
|
|
namespace {
|
|
|
|
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
|
|
|
Location loc = op->getLoc();
|
2022-05-13 20:06:24 +08:00
|
|
|
TypeConverter *typeConverter = getTypeConverter();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-03-30 23:14:23 +08:00
|
|
|
|
|
|
|
Type inputElementType =
|
|
|
|
self.getType().cast<RankedTensorType>().getElementType();
|
|
|
|
Type resultType = getTypeConverter()->convertType(op.getType());
|
|
|
|
Type resultElementType =
|
|
|
|
resultType.cast<RankedTensorType>().getElementType();
|
|
|
|
|
2022-05-03 21:22:42 +08:00
|
|
|
bool ceilMode;
|
2022-05-13 20:06:24 +08:00
|
|
|
SmallVector<Value, 2> kernelSizeIntValues;
|
|
|
|
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
|
2022-03-30 23:14:23 +08:00
|
|
|
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
|
2022-05-13 20:06:24 +08:00
|
|
|
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
|
|
|
|
strideInts, paddingInts)))
|
2022-03-30 23:14:23 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
|
|
|
|
|
|
|
// TODO: Add support for count_include_pad equal to `False`.
|
|
|
|
bool countIncludePad;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getCountIncludePad(),
|
2022-03-30 23:14:23 +08:00
|
|
|
m_TorchConstantBool(&countIncludePad)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "count_include_pad must be a constant");
|
|
|
|
if (!countIncludePad) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: count_include_pad is expected to be true");
|
|
|
|
}
|
|
|
|
|
|
|
|
// `sumPool2d` contains the result of sumpool2d operation over the input.
|
|
|
|
Value sumPool2d, paddedInput;
|
|
|
|
SmallVector<Value, 4> outTensorShape;
|
|
|
|
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
|
2022-10-27 08:30:45 +08:00
|
|
|
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
2022-05-13 20:06:24 +08:00
|
|
|
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
2022-03-30 23:14:23 +08:00
|
|
|
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
|
|
|
|
sumPool2d)))
|
|
|
|
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
|
|
|
|
|
|
|
|
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
|
|
|
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
2022-12-08 04:20:41 +08:00
|
|
|
Value divisor = op.getDivisorOverride().getType().isa<Torch::NoneType>()
|
2022-03-30 23:14:23 +08:00
|
|
|
? kHtimeskW
|
2022-12-08 04:20:41 +08:00
|
|
|
: adaptor.getDivisorOverride();
|
2022-03-30 23:14:23 +08:00
|
|
|
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
|
|
|
|
|
2022-10-18 12:22:53 +08:00
|
|
|
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(outTensorShape), resultElementType);
|
2022-03-30 23:14:23 +08:00
|
|
|
SmallVector<AffineMap> indexingMapsAvg(2,
|
|
|
|
rewriter.getMultiDimIdentityMap(4));
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypesAvg(
|
|
|
|
4, utils::IteratorType::parallel);
|
2022-03-30 23:14:23 +08:00
|
|
|
|
|
|
|
Value avgPool2d =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, outputTensor.getType(), sumPool2d, outputTensor,
|
|
|
|
/*indexingMaps=*/indexingMapsAvg,
|
|
|
|
/*iteratorTypes=*/iteratorTypesAvg,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
Value avg;
|
|
|
|
if (resultElementType.isa<mlir::IntegerType>())
|
|
|
|
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
|
|
|
else if (resultElementType.isa<mlir::FloatType>())
|
|
|
|
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
|
|
|
b.create<linalg::YieldOp>(loc, avg);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
// Marks the 2-D pooling ops handled in this file as illegal on `target` and
// registers the conversion patterns that lower them to Linalg.
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();
  target.addIllegalOp<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp,
                      AtenAvgPool2dOp>();
  patterns.add<ConvertAtenMaxPool2dOp, ConvertAtenMaxPool2dWithIndicesOp,
               ConvertAtenAvgPool2dOp>(typeConverter, ctx);
}
|