mirror of https://github.com/llvm/torch-mlir

merge indices and pooling computation into one linalg generic op

parent 04740824ae
commit ae1fa20290
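In short: the old path built the max values with one linalg.generic and then ran a second generic (computeMaxPoolingIndices, deleted below) that re-scanned every window to recover the indices; the new createCustomMaxPoolingOp emits a single linalg.generic whose region carries both accumulators, the running max and its flat source index, and selects between them with an scf.if, which is why the first hunk adds the SCF include. A scalar model of the fused reduction (a sketch only, assuming NCHW maxpool2d; the names and the explicit bounds check are mine, not code from the patch):

// Scalar model of the fused max+indices reduction (illustrative sketch only;
// names and the NCHW/maxpool2d framing are assumptions, not code from the
// patch). Padding reads -inf, as in the lowering below.
#include <cstdint>
#include <limits>
#include <vector>

struct Pool2dParams {
  int kH, kW, strideH, strideW, padH, padW, dilH, dilW;
};

// Compute one output element (oh, ow) of a single HxW channel plane,
// yielding the max value and the flat index it came from.
inline void fusedMaxWithIndex(const std::vector<float> &plane, int H, int W,
                              int oh, int ow, const Pool2dParams &p,
                              float &outMax, int64_t &outIdx) {
  const float ninf = -std::numeric_limits<float>::infinity();
  outMax = ninf; // accumulator init, mirrors smallestFPValueAttr
  outIdx = -1;   // accumulator init, mirrors cstMinusOne
  for (int kh = 0; kh < p.kH; ++kh) {
    for (int kw = 0; kw < p.kW; ++kw) {
      // Padded-input coordinate, as the affine map d_out*stride + k*dilation.
      int ih = oh * p.strideH + kh * p.dilH;
      int iw = ow * p.strideW + kw * p.dilW;
      // Back to the unpadded input; out-of-range positions are padding.
      int h = ih - p.padH, w = iw - p.padW;
      bool inBounds = h >= 0 && h < H && w >= 0 && w < W;
      float cur = inBounds ? plane[h * W + w] : ninf;
      // Body of the fused generic: overwrite on strictly-greater, or when
      // the index accumulator still holds its -1 init.
      if (cur > outMax || outIdx == -1) {
        outMax = cur;
        outIdx = inBounds ? int64_t(h) * W + w : -1;
      }
    }
  }
}

The real lowering never branches on bounds; padInputTensor materializes the -inf padding values and the emitted validIndex test only guards the low side. The sketch folds both into one in-bounds check for brevity.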
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -208,8 +209,6 @@ static LogicalResult createPoolingOp(
   return success();
 }

-namespace {
-
 template <typename T> struct DimensionTraits {};

 template <> struct DimensionTraits<AtenMaxPool1dOp> {
@@ -238,247 +237,156 @@
 struct DimensionTraits<AtenMaxPool3dWithIndicesOp>
     : DimensionTraits<AtenMaxPool3dOp> {};

+template <typename OpTy>
+LogicalResult createCustomMaxPoolingOp(
+    OpTy &op, typename OpTy::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter *typeConverter,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
+    SmallVectorImpl<int64_t> &dilationInts, bool ceilMode,
+    SmallVectorImpl<Value> &outTensorShape, Value &paddedInput,
+    ValueRange &results,
+    std::function<Value(OpBuilder &builder, Location loc,
+                        ValueRange iteratorDims)> &&indicesComputation =
+        nullptr) {
+  constexpr bool withIndices =
+      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
+                      AtenMaxPool3dWithIndicesOp>::value;
+  constexpr int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  if (withIndices && !indicesComputation) {
+    return op->emitError("need to provide indices computation functor for "
+                         "lowering maxpool with indices op");
+  }
+
+  Value self = adaptor.getSelf();
+  Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
+  TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
+      elementType,
+      APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                      /*Negative=*/true));
+
+  Value initValue =
+      rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
+
+  paddedInput = padInputTensor(op, rewriter, self, ceilMode, Dim, strideInts,
+                               paddingInts, initValue);
+
+  auto maxOutputInitialized = computeOutputTensor(
+      op, rewriter, self, Dim, ceilMode, strideInts, paddingInts, dilationInts,
+      kernelSizeIntValues, outTensorShape, initValue);
+
+  auto shape =
+      castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues);
+  Value windowTensor = rewriter.create<tensor::EmptyOp>(
+      op->getLoc(), getAsOpFoldResult(shape), elementType);
+
+  MLIRContext *context = rewriter.getContext();
+
+  SmallVector<mlir::AffineExpr> inputIndexing(
+      {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)});
+  SmallVector<mlir::AffineExpr> maxOutputIndexing = inputIndexing;
+  SmallVector<mlir::AffineExpr> kernelIndexing;
+  std::optional<SmallVector<mlir::AffineExpr>> indicesIndexing;
+  for (int i = 0; i < Dim; i++) {
+    mlir::AffineExpr poolingDim = rewriter.getAffineDimExpr(i + 2);
+    mlir::AffineExpr kernelDim = rewriter.getAffineDimExpr(i + 2 + Dim);
+    inputIndexing.push_back(
+        poolingDim * getAffineConstantExpr(strideInts[i], context) +
+        kernelDim * getAffineConstantExpr(dilationInts[i], context));
+    maxOutputIndexing.push_back(poolingDim);
+    kernelIndexing.push_back(kernelDim);
+  }
+
+  auto iteratorTypes =
+      SmallVector<utils::IteratorType>(2 + Dim, utils::IteratorType::parallel);
+  iteratorTypes.append(Dim, utils::IteratorType::reduction);
+  SmallVector<AffineMap> indexingMaps = {
+      mlir::AffineMap::get(2 + Dim * 2, 0, inputIndexing, context),
+      mlir::AffineMap::get(2 + Dim * 2, 0, kernelIndexing, context),
+      mlir::AffineMap::get(2 + Dim * 2, 0, maxOutputIndexing, context)};
+  SmallVector<mlir::Value> outputs({maxOutputInitialized});
+  SmallVector<mlir::Type> outTypes({maxOutputInitialized.getType()});
+
+  if constexpr (withIndices) {
+    // Indices tensor has same indexing/shape as max value tensor.
+    indexingMaps.push_back(
+        mlir::AffineMap::get(2 + Dim * 2, 0, maxOutputIndexing, context));
+    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
+        typeConverter->convertType(op->getResult(1).getType()));
+    Value cstMinusOne = rewriter.create<arith::ConstantOp>(
+        op->getLoc(), rewriter.getI64IntegerAttr(-1));
+    Value indicesOutputInitialized =
+        createInitTensor(rewriter, op->getLoc(), outTensorShape,
+                         indicesRankedTensorType.getElementType(), cstMinusOne);
+    outputs.push_back(indicesOutputInitialized);
+    outTypes.push_back(indicesOutputInitialized.getType());
+  }
+
+  results =
+      rewriter
+          .create<linalg::GenericOp>(
+              op->getLoc(),
+              /*result_types=*/outTypes,
+              /*operands=*/ValueRange({paddedInput, windowTensor}),
+              /*outputs=*/outputs,
+              /*indexingMaps=*/indexingMaps,
+              /*iteratorTypes=*/iteratorTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value currentVal = args[0], accMaxValue = args[2];
+                if constexpr (withIndices) {
+                  Value curIndex = args[3];
+                  SmallVector<Value> iterators;
+                  for (int i = 0; i < Dim * 2; i++) {
+                    iterators.push_back(b.create<linalg::IndexOp>(loc, i + 2));
+                  }
+                  Value pred = b.create<arith::CmpFOp>(
+                      loc, arith::CmpFPredicate::UGT, currentVal, accMaxValue);
+                  // Consider the corner case: the max pooling result is same as
+                  // padding value, which is -inf. We should return the first
+                  // index of pooling window but not -1.
+                  pred = b.create<arith::OrIOp>(
+                      loc, pred,
+                      b.create<arith::CmpIOp>(
+                          loc, arith::CmpIPredicate::eq, curIndex,
+                          b.create<arith::ConstantOp>(
+                              loc, b.getI64IntegerAttr(-1))));
+                  ValueRange outResults =
+                      b.create<mlir::scf::IfOp>(
+                           loc, pred,
+                           [&](OpBuilder &b, Location loc) {
+                             SmallVector<Value> curResults{currentVal};
+                             if constexpr (withIndices) {
+                               curResults.push_back(
+                                   indicesComputation(b, loc, iterators));
+                             }
+                             b.create<scf::YieldOp>(loc, curResults);
+                           },
+                           [&](OpBuilder &b, Location loc) {
+                             b.create<scf::YieldOp>(loc,
+                                                    args.drop_front(/*n=*/2));
+                           })
+                          ->getResults();
+                  b.create<linalg::YieldOp>(loc, outResults);
+                } else {
+                  Value max_result =
+                      b.create<arith::MaximumFOp>(loc, currentVal, accMaxValue);
+                  b.create<linalg::YieldOp>(loc, max_result);
+                }
+              })
+          ->getResults();
+
+  return success();
+}
+
+namespace {
 template <typename OpTy>
 class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;

-  static const bool withIndices =
-      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
-                      AtenMaxPool3dWithIndicesOp>::value;
-
 private:
   static const int64_t Dim = DimensionTraits<OpTy>::Dim;

-  LogicalResult createPoolingMax3D(OpTy &op, typename OpTy::Adaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   SmallVectorImpl<Value> &kernelSizeIntValues,
-                                   SmallVectorImpl<int64_t> &strideInts,
-                                   SmallVectorImpl<int64_t> &paddingInts,
-                                   SmallVectorImpl<int64_t> &dilationInts,
-                                   bool ceilMode,
-                                   SmallVectorImpl<Value> &outTensorShape,
-                                   Value &paddedInput, Value &poolingOp) const {
-    static_assert(Dim == 3, "op must be MaxPool3d or MaxPool3dWithIndices");
-    Value self = adaptor.getSelf();
-    Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
-    TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
-    Value initValue =
-        rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
-
-    paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, strideInts,
-                                 paddingInts, initValue);
-
-    auto outTensorInitialized = computeOutputTensor(
-        op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts,
-        kernelSizeIntValues, outTensorShape, initValue);
-
-    auto shape =
-        castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues);
-    Value windowTensor = rewriter.create<tensor::EmptyOp>(
-        op->getLoc(), getAsOpFoldResult(shape), elementType);
-
-    MLIRContext *context = rewriter.getContext();
-
-    auto mapInput = mlir::AffineMap::get(
-        8, 0,
-        {
-            rewriter.getAffineDimExpr(0), // n
-            rewriter.getAffineDimExpr(1), // c
-            // dim_d * stride_d + kernal_d * dilation_d
-            rewriter.getAffineDimExpr(2) *
-                    getAffineConstantExpr(strideInts[0], context) +
-                rewriter.getAffineDimExpr(5) *
-                    getAffineConstantExpr(dilationInts[0], context),
-            // dim_h * stride_h + kernal_h * dilation_h
-            rewriter.getAffineDimExpr(3) *
-                    getAffineConstantExpr(strideInts[1], context) +
-                rewriter.getAffineDimExpr(6) *
-                    getAffineConstantExpr(dilationInts[1], context),
-            // dim_w * stride_w + kernal_w * dilation_w
-            rewriter.getAffineDimExpr(4) *
-                    getAffineConstantExpr(strideInts[2], context) +
-                rewriter.getAffineDimExpr(7) *
-                    getAffineConstantExpr(dilationInts[2], context),
-        },
-        context);
-    auto mapKernel =
-        mlir::AffineMap::get(8, 0,
-                             {
-                                 rewriter.getAffineDimExpr(5), // kd
-                                 rewriter.getAffineDimExpr(6), // kh
-                                 rewriter.getAffineDimExpr(7)  // kw
-                             },
-                             context);
-    auto mapOutput = mlir::AffineMap::get(
-        8, 0,
-        {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1),
-         rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3),
-         rewriter.getAffineDimExpr(4)},
-        context);
-    auto iteratorTypes =
-        SmallVector<utils::IteratorType>(5, utils::IteratorType::parallel);
-    iteratorTypes.append(3, utils::IteratorType::reduction);
-    SmallVector<AffineMap> indexingMaps = {mapInput, mapKernel, mapOutput};
-    poolingOp = rewriter
-                    .create<linalg::GenericOp>(
-                        op->getLoc(),
-                        /* result types */ outTensorInitialized.getType(),
-                        /* operands */ ValueRange({paddedInput, windowTensor}),
-                        /* outputs */ outTensorInitialized,
-                        /*indexingMaps=*/indexingMaps,
-                        /*iteratorTypes=*/iteratorTypes,
-                        [&](OpBuilder &b, Location loc, ValueRange args) {
-                          Value currentVal = args[0], accMaxValue = args[2];
-                          Value max_result = b.create<arith::MaximumFOp>(
-                              loc, currentVal, accMaxValue);
-                          b.create<linalg::YieldOp>(loc, max_result);
-                        })
-                    .getResult(0);
-
-    return success();
-  }
-
-  // Returns the corresponding indices of the input tensor for the max pooling
-  // result tensor.
-  //
-  // For finding the indices, we follow the below method:
-  //
-  // Take maxpool2d as an example to illustrate. Let's say the input tensor is a
-  // 4-d tensor. The maxpool2d and indices will also be a 4-d tensor. Then:
-  // for i in range(N):
-  //   for j in range(C):
-  //     for m in range(Hout):
-  //       for n in range(Wout):
-  //         for p in range(kH):
-  //           for r in range(kW):
-  //             indexH = m * stride[0] + p * dilation[0]
-  //             indexW = n * stride[0] + r * dilation[0]
-  //             if paddedInput[i, j, indexH, indexW] ==
-  //                maxPool2d[i, j, m, n]:
-  //               indices[i, j, m, n] =
-  //                 (indexH - padding[0]) * W +
-  //                 (indexW - padding[1])
-  //
-  LogicalResult
-  computeMaxPoolingIndices(Value maxPool, Value paddedInput, OpTy &op,
-                           typename OpTy::Adaptor adaptor,
-                           ConversionPatternRewriter &rewriter,
-                           SmallVectorImpl<Value> &outTensorShape,
-                           SmallVectorImpl<Value> &kernelSizeIntValues,
-                           SmallVectorImpl<int64_t> &strideInts,
-                           SmallVectorImpl<int64_t> &paddingInts,
-                           SmallVectorImpl<int64_t> &dilationInts, int64_t rank,
-                           Value &indicesResult) const {
-    Location loc = op->getLoc();
-    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
-        this->getTypeConverter()->convertType(op->getResult(1).getType()));
-    Value cstMinusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
-    Value indicesTensor =
-        createInitTensor(rewriter, loc, outTensorShape,
-                         indicesRankedTensorType.getElementType(), cstMinusOne);
-
-    SmallVector<Value> kernelSize =
-        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
-    SmallVector<Value> padding =
-        getAsConstantIndexValues(rewriter, loc, paddingInts);
-    SmallVector<Value> dilation =
-        getAsConstantIndexValues(rewriter, loc, dilationInts);
-    SmallVector<Value> kernelStride =
-        getAsConstantIndexValues(rewriter, loc, strideInts);
-
-    Value windowTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(kernelSize),
-        indicesRankedTensorType.getElementType());
-
-    SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
-    for (unsigned i = 0; i < rank; i++) {
-      inputExprs.push_back(rewriter.getAffineDimExpr(i));
-      outputExprs.push_back(rewriter.getAffineDimExpr(i));
-    }
-    for (unsigned i = 0; i < rank - 2; i++) {
-      kernelExprs.push_back(rewriter.getAffineDimExpr(i + rank));
-    }
-
-    // If computing indices for maxpool2d, we have six dimensions here. Each
-    // corresponding to N, C, Hout, Wout, kH, and kW, respectively, as described
-    // in the algorithm above.
-    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
-        {inputExprs, kernelExprs, outputExprs}, rewriter.getContext());
-    SmallVector<utils::IteratorType> iteratorTypes(
-        rank, utils::IteratorType::parallel);
-    iteratorTypes.append(rank - 2, utils::IteratorType::reduction);
-
-    // Extract pooling dimensions of input shape.
-    SmallVector<Value> inputSubShape;
-    for (unsigned i = 0; i < rank - 2; i++) {
-      inputSubShape.push_back(
-          getDimOp(rewriter, loc, adaptor.getSelf(), i + 2));
-    }
-
-    indicesResult =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, /*resultTensorTypes=*/indicesTensor.getType(),
-                /*inputs=*/ValueRange({maxPool, windowTensor}),
-                /*outputs=*/indicesTensor,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value maxVal = args[0], res = args[2];
-
-                  SmallVector<Value> inputDims;
-                  inputDims.append({b.create<linalg::IndexOp>(loc, 0),
-                                    b.create<linalg::IndexOp>(loc, 1)});
-                  for (unsigned i = 2; i < rank; i++) {
-                    Value mainIndex = b.create<linalg::IndexOp>(loc, i);
-                    Value subIndex =
-                        b.create<linalg::IndexOp>(loc, i + rank - 2);
-                    Value origin = b.create<arith::MulIOp>(loc, mainIndex,
-                                                           kernelStride[i - 2]);
-                    Value offset =
-                        b.create<arith::MulIOp>(loc, subIndex, dilation[i - 2]);
-                    inputDims.push_back(
-                        b.create<arith::AddIOp>(loc, origin, offset));
-                  }
-
-                  Value input =
-                      b.create<tensor::ExtractOp>(loc, paddedInput, inputDims);
-                  Value pred = b.create<arith::CmpFOp>(
-                      loc, arith::CmpFPredicate::OEQ, input, maxVal);
-
-                  Value outIndex =
-                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
-                  Value curInputStride =
-                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
-                  for (unsigned i = 0; i < rank - 2; i++) {
-                    Value minusPadding = b.create<arith::SubIOp>(
-                        loc, inputDims[rank - 1 - i], padding[rank - 3 - i]);
-                    Value timesStride = b.create<arith::MulIOp>(
-                        loc, minusPadding, curInputStride);
-                    outIndex =
-                        b.create<arith::AddIOp>(loc, outIndex, timesStride);
-                    curInputStride = b.create<arith::MulIOp>(
-                        loc, curInputStride, inputSubShape[rank - 3 - i]);
-                  }
-                  Value result = b.create<arith::SelectOp>(
-                      loc, pred, castIndexToInt64(b, loc, outIndex), res);
-
-                  Value predInvalidIndex = b.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, res, cstMinusOne);
-                  Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
-                                                        result, res);
-
-                  b.create<linalg::YieldOp>(loc, out);
-                })
-            .getResult(0);
-
-    return success();
-  }
-
 public:
   LogicalResult
   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
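createCustomMaxPoolingOp leaves the actual index computation to the indicesComputation functor; the with-indices pattern in the next hunk passes a lambda that turns the winning window position in the padded input back into a flat index of the unpadded input, the same arithmetic as the comment block deleted above. A standalone sketch of that linearization (a hypothetical helper with illustrative names, not part of the patch; assumes row-major layout):

// Sketch of the flat-index computation used for the indices result.
// paddedPos holds the winning position in the padded input for each pooling
// dim; dimSizes holds the unpadded sizes of those dims.
#include <cstdint>
#include <vector>

inline int64_t linearizeIndex(const std::vector<int64_t> &paddedPos,
                              const std::vector<int64_t> &padding,
                              const std::vector<int64_t> &dimSizes) {
  int64_t outIndex = 0, stride = 1;
  bool valid = true;
  const int dims = static_cast<int>(paddedPos.size());
  for (int i = 0; i < dims; ++i) { // walk innermost dim first
    int64_t minusPadding = paddedPos[dims - 1 - i] - padding[dims - 1 - i];
    valid = valid && minusPadding >= 0; // a max inside the padding yields -1
    outIndex += minusPadding * stride;
    stride *= dimSizes[dims - 1 - i];   // running row-major stride
  }
  return valid ? outIndex : -1;
}
// For maxpool2d this reduces to (indexH - padH) * W + (indexW - padW),
// matching the algorithm comment removed in the hunk above.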
@@ -546,32 +454,124 @@ public:
                                         paddedInput, maxPool)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
     } else {
-      if (failed(createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
-                                    strideInts, paddingInts, dilationInts,
-                                    ceilMode, outTensorShape, paddedInput,
-                                    maxPool)))
+      ValueRange poolingResults;
+      if (failed(createCustomMaxPoolingOp(
+              op, adaptor, rewriter, typeConverter, kernelSizeIntValues,
+              strideInts, paddingInts, dilationInts, ceilMode, outTensorShape,
+              paddedInput, poolingResults)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool3d");
+      maxPool = poolingResults.front();
     }

-    Value outMaxPool = rewriter.create<tensor::CastOp>(
-        op->getLoc(), maxPoolResultType, maxPool);
-    SmallVector<Value> outResult({outMaxPool});
-    if (withIndices) {
-      Value indicesResult;
-      if (failed(computeMaxPoolingIndices(
-              maxPool, paddedInput, op, adaptor, rewriter, outTensorShape,
-              kernelSizeIntValues, strideInts, paddingInts, dilationInts,
-              selfRank, indicesResult)))
-        return rewriter.notifyMatchFailure(op,
-                                           "unable to compute maxpool indices");
-      Type indicesResultType =
-          typeConverter->convertType(op->getResult(1).getType());
-      Value outIndices = rewriter.create<tensor::CastOp>(
-          op->getLoc(), indicesResultType, indicesResult);
-      outResult.push_back(outIndices);
-    }
-    rewriter.replaceOp(op, outResult);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, maxPoolResultType, maxPool);
+    return success();
+  }
+};
+} // namespace

+namespace {
+template <typename OpTy>
+class ConvertAtenMaxPoolWithIndicesOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    const TypeConverter *typeConverter = this->getTypeConverter();
+
+    bool ceilMode;
+    SmallVector<Value, Dim> kernelSizeIntValues;
+    SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts;
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only support constant int dilations");
+
+    if (failed(checkAndGetPoolingParameters<OpTy>(op, rewriter, typeConverter,
+                                                  ceilMode, kernelSizeIntValues,
+                                                  strideInts, paddingInts)))
+      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
+
+    // Initialize padding/dilation/kernelStride to help computing indices
+    // correspond to max pooling values.
+    SmallVector<Value> kernelSize =
+        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
+    SmallVector<Value> padding =
+        getAsConstantIndexValues(rewriter, loc, paddingInts);
+    SmallVector<Value> dilation =
+        getAsConstantIndexValues(rewriter, loc, dilationInts);
+    SmallVector<Value> kernelStride =
+        getAsConstantIndexValues(rewriter, loc, strideInts);
+    // Extract pooling dimensions of input shape.
+    SmallVector<Value> inputSubShape;
+    for (int i = 0; i < Dim; i++) {
+      inputSubShape.push_back(
+          getDimOp(rewriter, loc, adaptor.getSelf(), i + 2));
+    }
+
+    auto indicesComputation = [&](OpBuilder &b, Location loc,
+                                  ValueRange iteratorDims) -> Value {
+      SmallVector<Value> inputDims;
+      for (int i = 0; i < Dim; i++) {
+        Value origin =
+            b.create<arith::MulIOp>(loc, iteratorDims[i], kernelStride[i]);
+        Value offset =
+            b.create<arith::MulIOp>(loc, iteratorDims[i + Dim], dilation[i]);
+        inputDims.push_back(b.create<arith::AddIOp>(loc, origin, offset));
+      }
+
+      Value outIndex = b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
+      Value curInputStride =
+          b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
+      Value validIndex =
+          b.create<arith::ConstantOp>(loc, b.getIntegerAttr(b.getI1Type(), 1));
+      Value cstZero = b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
+      for (int i = 0; i < Dim; i++) {
+        Value minusPadding = b.create<arith::SubIOp>(
+            loc, inputDims[Dim - 1 - i], padding[Dim - 1 - i]);
+        validIndex = b.create<arith::AndIOp>(
+            loc, validIndex,
+            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
+                                    minusPadding, cstZero));
+        Value timesStride =
+            b.create<arith::MulIOp>(loc, minusPadding, curInputStride);
+        outIndex = b.create<arith::AddIOp>(loc, outIndex, timesStride);
+        curInputStride = b.create<arith::MulIOp>(loc, curInputStride,
+                                                 inputSubShape[Dim - 1 - i]);
+      }
+      return b.create<arith::SelectOp>(
+          loc, validIndex, castIndexToInt64(b, loc, outIndex),
+          b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(-1)));
+    };
+
+    Value paddedInput;
+    SmallVector<Value, 4> outTensorShape;
+    ValueRange results;
+    if (failed(createCustomMaxPoolingOp(
+            op, adaptor, rewriter, typeConverter, kernelSizeIntValues,
+            strideInts, paddingInts, dilationInts, ceilMode, outTensorShape,
+            paddedInput, results, std::move(indicesComputation))))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute maxpool with indices");
+
+    Type maxPoolResultType =
+        typeConverter->convertType(op->getResult(0).getType());
+    Type indicesResultType =
+        typeConverter->convertType(op->getResult(1).getType());
+    Value outMaxpool = rewriter.create<tensor::CastOp>(loc, maxPoolResultType,
+                                                       results.front());
+    Value outIndices =
+        rewriter.create<tensor::CastOp>(loc, indicesResultType, results.back());
+
+    rewriter.replaceOp(op, {outMaxpool, outIndices});
     return success();
   }
 };
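One subtlety in the fused region: the float comparison uses arith::CmpFPredicate::UGT, so a later window element that merely ties the running max does not overwrite it, and the curIndex == -1 disjunct guarantees the accumulator is overwritten on the first visit even when the max equals the -inf padding value, per the comment in the region. A small self-check of that update rule (plain C++ with hypothetical values, a model of the region body, not torch-mlir code):

#include <cassert>
#include <cstdint>
#include <limits>

// The per-element update rule modeled from the fused generic's body.
static void update(float cur, int64_t curFlatIdx, float &accVal,
                   int64_t &accIdx) {
  // pred = (cur UGT acc) or (accIdx == -1), as built with arith::OrIOp.
  if (cur > accVal || accIdx == -1) {
    accVal = cur;
    accIdx = curFlatIdx;
  }
}

int main() {
  float accVal = -std::numeric_limits<float>::infinity();
  int64_t accIdx = -1;                   // initialization of both accumulators
  update(3.0f, 0, accVal, accIdx);       // first element always writes
  update(3.0f, 1, accVal, accIdx);       // a tie is not strictly greater
  assert(accVal == 3.0f && accIdx == 0); // first index of the tie is kept
  return 0;
}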
@@ -1521,10 +1521,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(

   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
-  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
-                                                                 context);
-  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
-                                                                 context);
+  patterns.add<ConvertAtenMaxPoolWithIndicesOp<AtenMaxPool2dWithIndicesOp>>(
+      typeConverter, context);
+  patterns.add<ConvertAtenMaxPoolWithIndicesOp<AtenMaxPool3dWithIndicesOp>>(
+      typeConverter, context);

   target.addIllegalOp<AtenMaxUnpool3dOp>();
   patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);