mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Support lowering MaxPool3dWithIndices (#3652)
Support torch.MaxPool3dWithIndices lowering to linalg backend.

parent b92e61832f
commit 5bc59ce1fa
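For context: `aten.max_pool3d_with_indices` returns the pooled values plus, for each output element, an int64 offset into the flattened spatial volume of the (unpadded) input. A minimal sketch of that contract in plain PyTorch (the shapes and parameters below are illustrative only, not taken from this patch):

    import torch

    x = torch.arange(64, dtype=torch.float32).reshape(1, 1, 4, 4, 4)
    vals, idx = torch.ops.aten.max_pool3d_with_indices(
        x, kernel_size=[2, 2, 2], stride=[1, 1, 1], padding=[0, 0, 0],
        dilation=[1, 1, 1])
    # Indices address the flattened D*H*W spatial volume per (N, C), so
    # gathering from the flattened input reproduces the pooled values.
    flat = x.flatten(2)                      # [N, C, D*H*W]
    assert torch.equal(flat.gather(2, idx.flatten(2)), vals.flatten(2))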
@@ -224,28 +224,41 @@ template <> struct DimensionTraits<AtenMaxPool2dOp> {
   static_assert(Dim == Dim);
 };
 
+template <>
+struct DimensionTraits<AtenMaxPool2dWithIndicesOp>
+    : DimensionTraits<AtenMaxPool2dOp> {};
+
 template <> struct DimensionTraits<AtenMaxPool3dOp> {
   static constexpr int64_t Dim = 3;
   // unused const variable warning suppression:
   static_assert(Dim == Dim);
 };
 
+template <>
+struct DimensionTraits<AtenMaxPool3dWithIndicesOp>
+    : DimensionTraits<AtenMaxPool3dOp> {};
+
 template <typename OpTy>
 class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
+  static const bool withIndices =
+      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
+                      AtenMaxPool3dWithIndicesOp>::value;
+
 private:
   static const int64_t Dim = DimensionTraits<OpTy>::Dim;
 
-  LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op,
-                                   typename OpTy::Adaptor adaptor,
+  LogicalResult createPoolingMax3D(OpTy &op, typename OpTy::Adaptor adaptor,
                                    ConversionPatternRewriter &rewriter,
                                    SmallVectorImpl<Value> &kernelSizeIntValues,
                                    SmallVectorImpl<int64_t> &strideInts,
                                    SmallVectorImpl<int64_t> &paddingInts,
                                    SmallVectorImpl<int64_t> &dilationInts,
-                                   bool ceilMode) const {
-    SmallVector<Value, 5> outTensorShape;
+                                   bool ceilMode,
+                                   SmallVectorImpl<Value> &outTensorShape,
+                                   Value &paddedInput, Value &poolingOp) const {
+    static_assert(Dim == 3, "op must be MaxPool3d or MaxPool3dWithIndices");
     Value self = adaptor.getSelf();
     Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
     TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
@@ -255,8 +268,8 @@ private:
     Value initValue =
         rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
 
-    Value paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3,
-                                       strideInts, paddingInts, initValue);
+    paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, strideInts,
+                                 paddingInts, initValue);
 
     auto outTensorInitialized = computeOutputTensor(
         op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts,
@@ -309,25 +322,160 @@ private:
         SmallVector<utils::IteratorType>(5, utils::IteratorType::parallel);
     iteratorTypes.append(3, utils::IteratorType::reduction);
     SmallVector<AffineMap> indexingMaps = {mapInput, mapKernel, mapOutput};
-    Value poolingOp =
-        rewriter
-            .create<linalg::GenericOp>(
-                op->getLoc(),
-                /* result types */ outTensorInitialized.getType(),
-                /* operands */ ValueRange({paddedInput, windowTensor}),
-                /* outputs */ outTensorInitialized,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value currentVal = args[0], accMaxValue = args[2];
-                  Value max_result =
-                      b.create<arith::MaximumFOp>(loc, currentVal, accMaxValue);
-                  ;
-                  b.create<linalg::YieldOp>(loc, max_result);
-                })
-            .getResult(0);
-    Type newResultType = this->getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, poolingOp);
+    poolingOp = rewriter
+                    .create<linalg::GenericOp>(
+                        op->getLoc(),
+                        /* result types */ outTensorInitialized.getType(),
+                        /* operands */ ValueRange({paddedInput, windowTensor}),
+                        /* outputs */ outTensorInitialized,
+                        /*indexingMaps=*/indexingMaps,
+                        /*iteratorTypes=*/iteratorTypes,
+                        [&](OpBuilder &b, Location loc, ValueRange args) {
+                          Value currentVal = args[0], accMaxValue = args[2];
+                          Value max_result = b.create<arith::MaximumFOp>(
+                              loc, currentVal, accMaxValue);
+                          b.create<linalg::YieldOp>(loc, max_result);
+                        })
+                    .getResult(0);
+
     return success();
   }
+
+  // Returns the corresponding indices of the input tensor for the max pooling
+  // result tensor.
+  //
+  // For finding the indices, we follow the below method:
+  //
+  // Take maxpool2d as an example to illustrate. Let's say the input tensor is
+  // a 4-d tensor. The maxpool2d and indices will also be a 4-d tensor. Then:
+  // for i in range(N):
+  //   for j in range(C):
+  //     for m in range(Hout):
+  //       for n in range(Wout):
+  //         for p in range(kH):
+  //           for r in range(kW):
+  //             indexH = m * stride[0] + p * dilation[0]
+  //             indexW = n * stride[0] + r * dilation[0]
+  //             if paddedInput[i, j, indexH, indexW] ==
+  //                maxPool2d[i, j, m, n]:
+  //               indices[i, j, m, n] =
+  //                 (indexH - padding[0]) * W +
+  //                 (indexW - padding[1])
+  //
+  LogicalResult
+  computeMaxPoolingIndices(Value maxPool, Value paddedInput, OpTy &op,
+                           typename OpTy::Adaptor adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVectorImpl<Value> &outTensorShape,
+                           SmallVectorImpl<Value> &kernelSizeIntValues,
+                           SmallVectorImpl<int64_t> &strideInts,
+                           SmallVectorImpl<int64_t> &paddingInts,
+                           SmallVectorImpl<int64_t> &dilationInts, int64_t rank,
+                           Value &indicesResult) const {
+    Location loc = op->getLoc();
+    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
+        this->getTypeConverter()->convertType(op->getResult(1).getType()));
+    Value cstMinusOne =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value indicesTensor =
+        createInitTensor(rewriter, loc, outTensorShape,
+                         indicesRankedTensorType.getElementType(), cstMinusOne);
+
+    SmallVector<Value> kernelSize =
+        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
+    SmallVector<Value> padding =
+        getAsConstantIndexValues(rewriter, loc, paddingInts);
+    SmallVector<Value> dilation =
+        getAsConstantIndexValues(rewriter, loc, dilationInts);
+    SmallVector<Value> kernelStride =
+        getAsConstantIndexValues(rewriter, loc, strideInts);
+
+    Value windowTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getAsOpFoldResult(kernelSize),
+        indicesRankedTensorType.getElementType());
+
+    SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
+    for (unsigned i = 0; i < rank; i++) {
+      inputExprs.push_back(rewriter.getAffineDimExpr(i));
+      outputExprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    for (unsigned i = 0; i < rank - 2; i++) {
+      kernelExprs.push_back(rewriter.getAffineDimExpr(i + rank));
+    }
+
+    // If computing indices for maxpool2d, we have six dimensions here. Each
+    // corresponding to N, C, Hout, Wout, kH, and kW, respectively, as
+    // described in the algorithm above.
+    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
+        {inputExprs, kernelExprs, outputExprs}, rewriter.getContext());
+    SmallVector<utils::IteratorType> iteratorTypes(
+        rank, utils::IteratorType::parallel);
+    iteratorTypes.append(rank - 2, utils::IteratorType::reduction);
+
+    // Extract pooling dimensions of input shape.
+    SmallVector<Value> inputSubShape;
+    for (unsigned i = 0; i < rank - 2; i++) {
+      inputSubShape.push_back(
+          getDimOp(rewriter, loc, adaptor.getSelf(), i + 2));
+    }
+
+    indicesResult =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, /*resultTensorTypes=*/indicesTensor.getType(),
+                /*inputs=*/ValueRange({maxPool, windowTensor}),
+                /*outputs=*/indicesTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value maxVal = args[0], res = args[2];
+
+                  SmallVector<Value> inputDims;
+                  inputDims.append({b.create<linalg::IndexOp>(loc, 0),
+                                    b.create<linalg::IndexOp>(loc, 1)});
+                  for (unsigned i = 2; i < rank; i++) {
+                    Value mainIndex = b.create<linalg::IndexOp>(loc, i);
+                    Value subIndex =
+                        b.create<linalg::IndexOp>(loc, i + rank - 2);
+                    Value origin = b.create<arith::MulIOp>(loc, mainIndex,
+                                                           kernelStride[i - 2]);
+                    Value offset =
+                        b.create<arith::MulIOp>(loc, subIndex, dilation[i - 2]);
+                    inputDims.push_back(
+                        b.create<arith::AddIOp>(loc, origin, offset));
+                  }
+
+                  Value input =
+                      b.create<tensor::ExtractOp>(loc, paddedInput, inputDims);
+                  Value pred = b.create<arith::CmpFOp>(
+                      loc, arith::CmpFPredicate::OEQ, input, maxVal);
+
+                  Value outIndex =
+                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
+                  Value curInputStride =
+                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
+                  for (unsigned i = 0; i < rank - 2; i++) {
+                    Value minusPadding = b.create<arith::SubIOp>(
+                        loc, inputDims[rank - 1 - i], padding[rank - 3 - i]);
+                    Value timesStride = b.create<arith::MulIOp>(
+                        loc, minusPadding, curInputStride);
+                    outIndex =
+                        b.create<arith::AddIOp>(loc, outIndex, timesStride);
+                    curInputStride = b.create<arith::MulIOp>(
+                        loc, curInputStride, inputSubShape[rank - 3 - i]);
+                  }
+                  Value result = b.create<arith::SelectOp>(
+                      loc, pred, castIndexToInt64(b, loc, outIndex), res);
+
+                  Value predInvalidIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, res, cstMinusOne);
+                  Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
+                                                        result, res);
+
+                  b.create<linalg::YieldOp>(loc, out);
+                })
+            .getResult(0);
+
+    return success();
+  }
 
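The comment block above spells out the 2-d case; the `linalg.generic` body generalizes it to any rank. A reference sketch of the same search in plain Python (`recover_index` is a hypothetical helper written only to mirror the reduction in the lowering; `padded` is assumed to be an N-d array, e.g. NumPy, of one padded spatial volume):

    import itertools

    def recover_index(padded, pooled_max, out_pos, kernel, stride, dilation,
                      padding, spatial_shape):
        # The indices init tensor is filled with -1 ("no match found yet").
        best = -1
        for taps in itertools.product(*(range(k) for k in kernel)):
            # Position of this kernel tap in the padded input.
            pos = [o * s + t * d
                   for o, s, t, d in zip(out_pos, stride, taps, dilation)]
            if padded[tuple(pos)] == pooled_max and best == -1:
                # Row-major offset into the *unpadded* spatial volume,
                # accumulated from the innermost dimension outward.
                offset, running_stride = 0, 1
                for dim in reversed(range(len(pos))):
                    offset += (pos[dim] - padding[dim]) * running_stride
                    running_stride *= spatial_shape[dim]
                best = offset
        return best

The `-1` guard mirrors the pair of `arith.select` ops in the lowering: once an index has been written, later kernel taps that also equal the max do not overwrite it.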
@@ -377,214 +525,53 @@ public:
     if (!smallestValueAttr)
       return rewriter.notifyMatchFailure(op, "invalid element type");
 
+    // `maxPool` contains the result of maxpool 1d/2d/3d operation over the
+    // input, `paddedInput` means the padded result of input tensor.
+    Value maxPool, paddedInput;
+    Type maxPoolResultType =
+        typeConverter->convertType(op->getResult(0).getType());
+    SmallVector<Value, 5> outTensorShape;
     if constexpr (Dim == 1) {
-      SmallVector<Value, 4> outTensorShape;
-      Value maxPool1d, paddedInput;
       if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
               op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
               /*dimensionality=*/1, kernelSizeIntValues, strideInts,
               paddingInts, dilationInts, smallestValueAttr, outTensorShape,
-              paddedInput, maxPool1d)))
+              paddedInput, maxPool)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
-      Type newResultType = this->getTypeConverter()->convertType(op.getType());
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool1d);
-      return success();
     } else if constexpr (Dim == 2) {
-      SmallVector<Value, 4> outTensorShape;
-      // `maxpool2d` contains the result of maxpool2d operation over the input.
-      Value maxPool2d, paddedInput;
       if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
               op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
               /*dimensionality=*/2, kernelSizeIntValues, strideInts,
               paddingInts, dilationInts, smallestValueAttr, outTensorShape,
-              paddedInput, maxPool2d)))
+              paddedInput, maxPool)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
-      Type newResultType = this->getTypeConverter()->convertType(op.getType());
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
-      return success();
     } else {
-      return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
-                                strideInts, paddingInts, dilationInts,
-                                ceilMode);
+      if (failed(createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
+                                    strideInts, paddingInts, dilationInts,
+                                    ceilMode, outTensorShape, paddedInput,
+                                    maxPool)))
+        return rewriter.notifyMatchFailure(op, "unable to compute maxpool3d");
     }
-  }
-};
-} // namespace
-
-namespace {
-// Returns the result of maxpool2d over the input tensor. And the corresponding
-// indices of the input tensor for the values of the result tensor.
-//
-// The result of the maxpool2d operation is calculated using the helper function
-// written above. For finding the indices, we follow the below method:
-//
-// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will
-// also be a 4-d tensor. Then:
-// for i in range(N):
-//   for j in range(C):
-//     for m in range(Hout):
-//       for n in range(Wout):
-//         for p in range(kH):
-//           for r in range(kW):
-//             indexH = m * stride[0] + p * dilation[0]
-//             indexW = n * stride[0] + r * dilation[0]
-//             if paddedInput[i, j, indexH, indexW] ==
-//                maxPool2d[i, j, m, n]:
-//               indices[i, j, m, n] = (indexH - padding[0]) * W +
-//                                     (indexW - padding[1])
-//
-class ConvertAtenMaxPool2dWithIndicesOp
-    : public OpConversionPattern<AtenMaxPool2dWithIndicesOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-    Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
-    Value self = adaptor.getSelf();
-    RankedTensorType selfType = cast<RankedTensorType>(self.getType());
-    Type elementType = selfType.getElementType();
-    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
-        getTypeConverter()->convertType(op->getResult(1).getType()));
-
-    // TODO: Add support for 3D inputs.
-    if (selfType.getRank() == 3)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only support 4D input");
-
-    bool ceilMode;
-    SmallVector<Value, 2> kernelSizeIntValues;
-    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
-    if (!matchPattern(op.getDilation(),
-                      m_TorchListOfConstantInts(dilationInts)))
-      return rewriter.notifyMatchFailure(op,
-                                         "only support constant int dilations");
-    if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
-            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
-            strideInts, paddingInts)))
-      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
-
-    // `maxpool2d` contains the result of maxpool2d operation over the input.
-    auto smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
-    Value maxPool2d, paddedInput;
-    SmallVector<Value, 4> outTensorShape;
-    if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
-            op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
-            /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
-            dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
-            maxPool2d)))
-      return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
-
-    Value cstMinusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
-    Value indicesTensor =
-        createInitTensor(rewriter, loc, outTensorShape,
-                         indicesRankedTensorType.getElementType(), cstMinusOne);
-
-    SmallVector<Value> kernelSize =
-        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
-    SmallVector<Value> padding =
-        getAsConstantIndexValues(rewriter, loc, paddingInts);
-    SmallVector<Value> dilation =
-        getAsConstantIndexValues(rewriter, loc, dilationInts);
-    SmallVector<Value> stride =
-        getAsConstantIndexValues(rewriter, loc, strideInts);
-
-    Value windowTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(kernelSize),
-        indicesRankedTensorType.getElementType());
-
-    SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
-    for (unsigned i = 0; i < 4; i++) {
-      inputExprs.push_back(rewriter.getAffineDimExpr(i));
-      outputExprs.push_back(rewriter.getAffineDimExpr(i));
-    }
-    kernelExprs.push_back(rewriter.getAffineDimExpr(4));
-    kernelExprs.push_back(rewriter.getAffineDimExpr(5));
-
-    // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH,
-    // and kW, respectively, as described in the algorithm above.
-    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
-        {inputExprs, kernelExprs, outputExprs}, rewriter.getContext());
-    SmallVector<utils::IteratorType> iteratorTypes(
-        4, utils::IteratorType::parallel);
-    iteratorTypes.push_back(utils::IteratorType::reduction);
-    iteratorTypes.push_back(utils::IteratorType::reduction);
-
-    // Input format is : [N, C, H, W]
-    Value inputShapeW = getDimOp(rewriter, loc, self, 3);
-
-    Value indicesResult =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, /*resultTensorTypes=*/indicesTensor.getType(),
-                /*inputs=*/ValueRange({maxPool2d, windowTensor}),
-                /*outputs=*/indicesTensor,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value maxVal = args[0], res = args[2];
-
-                  Value i = b.create<linalg::IndexOp>(loc, 0);
-                  Value j = b.create<linalg::IndexOp>(loc, 1);
-                  Value m = b.create<linalg::IndexOp>(loc, 2);
-                  Value n = b.create<linalg::IndexOp>(loc, 3);
-                  Value p = b.create<linalg::IndexOp>(loc, 4);
-                  Value r = b.create<linalg::IndexOp>(loc, 5);
-
-                  Value mTimesStride =
-                      b.create<arith::MulIOp>(loc, m, stride[0]);
-                  Value pTimesDilation =
-                      b.create<arith::MulIOp>(loc, p, dilation[0]);
-                  Value indexH = b.create<arith::AddIOp>(loc, mTimesStride,
-                                                         pTimesDilation);
-                  Value nTimesStride =
-                      b.create<arith::MulIOp>(loc, n, stride[1]);
-                  Value rTimesDilation =
-                      b.create<arith::MulIOp>(loc, r, dilation[1]);
-                  Value indexW = b.create<arith::AddIOp>(loc, nTimesStride,
-                                                         rTimesDilation);
-                  Value input = b.create<tensor::ExtractOp>(
-                      loc, paddedInput, ValueRange{i, j, indexH, indexW});
-                  Value pred = b.create<arith::CmpFOp>(
-                      loc, arith::CmpFPredicate::OEQ, input, maxVal);
-
-                  Value indexHMinusPadding =
-                      b.create<arith::SubIOp>(loc, indexH, padding[0]);
-                  Value indexWMinusPadding =
-                      b.create<arith::SubIOp>(loc, indexW, padding[1]);
-                  Value outIndex = b.create<arith::MulIOp>(
-                      loc, indexHMinusPadding, inputShapeW);
-                  outIndex = b.create<arith::AddIOp>(loc, outIndex,
-                                                     indexWMinusPadding);
-                  Value result = b.create<arith::SelectOp>(
-                      loc, pred, castIndexToInt64(b, loc, outIndex), res);
-
-                  Value predInvalidIndex = b.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, res, cstMinusOne);
-                  Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
-                                                        result, res);
-
-                  b.create<linalg::YieldOp>(loc, out);
-                })
-            .getResult(0);
-
-    Type maxPool2dResultType =
-        getTypeConverter()->convertType(op->getResult(0).getType());
-    Type indicesResultType =
-        getTypeConverter()->convertType(op->getResult(1).getType());
-    Value outMaxpool2d =
-        rewriter.create<tensor::CastOp>(loc, maxPool2dResultType, maxPool2d);
-    Value outIndices =
-        rewriter.create<tensor::CastOp>(loc, indicesResultType, indicesResult);
-
-    rewriter.replaceOp(op, {outMaxpool2d, outIndices});
+
+    Value outMaxPool = rewriter.create<tensor::CastOp>(
+        op->getLoc(), maxPoolResultType, maxPool);
+    SmallVector<Value> outResult({outMaxPool});
+    if (withIndices) {
+      Value indicesResult;
+      if (failed(computeMaxPoolingIndices(
+              maxPool, paddedInput, op, adaptor, rewriter, outTensorShape,
+              kernelSizeIntValues, strideInts, paddingInts, dilationInts,
+              selfRank, indicesResult)))
+        return rewriter.notifyMatchFailure(op,
+                                           "unable to compute maxpool indices");
+      Type indicesResultType =
+          typeConverter->convertType(op->getResult(1).getType());
+      Value outIndices = rewriter.create<tensor::CastOp>(
+          op->getLoc(), indicesResultType, indicesResult);
+      outResult.push_back(outIndices);
+    }
+    rewriter.replaceOp(op, outResult);
     return success();
   }
 };
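With the pattern unified, the value computation is shared across the plain and `WithIndices` max-pool ops; only the latter get the indices result appended. A quick illustrative check of that contract at the PyTorch level (the input shape here is an assumption for demonstration):

    import torch

    x = torch.randn(1, 2, 8, 8, 8)
    pooled = torch.ops.aten.max_pool3d(x, kernel_size=[2, 2, 2])
    vals, _ = torch.ops.aten.max_pool3d_with_indices(x, kernel_size=[2, 2, 2])
    # The first result of the with-indices op matches the plain pooling op.
    assert torch.equal(pooled, vals)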
@@ -1533,7 +1520,11 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);
 
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
-  patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
+  target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
+  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
+                                                                 context);
+  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
+                                                                 context);
 
   target.addIllegalOp<AtenMaxUnpool3dOp>();
   patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
@@ -8163,6 +8163,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 " return %arg1 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.max_pool3d_with_indices\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
+" %0 = call @__torch__._max_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+" return %1 : !torch.tuple<list<int>, list<int>>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
 " %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
 " %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
@@ -11949,6 +11954,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 " return %1 : !torch.tuple<int, int>\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.max_pool3d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
+" %int4 = torch.constant.int 4\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+" return %1 : !torch.tuple<int, int>\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
@@ -448,13 +448,6 @@ FX_IMPORTER_XFAIL_SET = {
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "LenStrModule_basic",
-    "MaxPool3dCeilModeTrueModule_basic",
-    "MaxPool3dEmptyStrideStaticModule_basic",
-    "MaxPool3dLargeDatadModule_basic",
-    "MaxPool3dModuleRandomSimple_basic",
-    "MaxPool3dModule_basic",
-    "MaxPool3dStaticCeilModeTrueModule_basic",
-    "MaxPool3dStaticModule_basic",
    "MulFloatModule_basic",
     "NativeGroupNormBackwardModule_basic",
     "NeFloatIntModule_basic",
@@ -707,6 +700,16 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
+    "MaxPool3dWithIndicesAllNegativeValuesModule_basic",
+    "MaxPool3dWithIndicesAllOnesModule_basic",
+    "MaxPool3dWithIndicesCeilModeTrueModule_basic",
+    "MaxPool3dWithIndicesFullSizeKernelModule_basic",
+    "MaxPool3dWithIndicesModule_basic",
+    "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
+    "MaxPool3dWithIndicesNonDefaultPaddingModule_basic",
+    "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
+    "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool3dWithIndicesStaticModule_basic",
     "MseLossMeanReductionModule_basic",
     "MseLossSumReductionWithDifferentElemTypeModule_basic",
     "MulFloatModule_basic",
@@ -2585,6 +2588,13 @@ ONNX_XFAIL_SET = {
    "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
     "MaxPool3dModule_basic",
+    "MaxPool3dWithIndicesAllOnesModule_basic",
+    "MaxPool3dWithIndicesCeilModeTrueModule_basic",
+    "MaxPool3dWithIndicesFullSizeKernelModule_basic",
+    "MaxPool3dWithIndicesModule_basic",
+    "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
+    "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
+    "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxUnpool3dModule_basic",
     "MaxUnpool3dModulePad0_basic",
     "MeanDimEmptyDimModule_basic",
@@ -2914,6 +2924,8 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
     # Runtime crash: mismatched size for broadcast
     "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
     "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
+    "MaxPool3dWithIndicesAllNegativeValuesModule_basic",
+    "MaxPool3dWithIndicesNonDefaultPaddingModule_basic",
     "StdDimEmptyDimModule_basic",
     "StdCorrectionEmptyDimModule_basic",
     "VarCorrectionEmptyDimModule_basic",
@@ -3372,6 +3384,16 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
+    "MaxPool3dWithIndicesAllNegativeValuesModule_basic",
+    "MaxPool3dWithIndicesAllOnesModule_basic",
+    "MaxPool3dWithIndicesCeilModeTrueModule_basic",
+    "MaxPool3dWithIndicesFullSizeKernelModule_basic",
+    "MaxPool3dWithIndicesModule_basic",
+    "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
+    "MaxPool3dWithIndicesNonDefaultPaddingModule_basic",
+    "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
+    "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool3dWithIndicesStaticModule_basic",
     "MeanDimDtypeModule_basic",
     "MeanDimEmptyDimModule_basic",
     "MeanDimNoneDimModule_basic",
@@ -4244,6 +4266,16 @@ ONNX_TOSA_XFAIL_SET = {
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
+    "MaxPool3dWithIndicesAllNegativeValuesModule_basic",
+    "MaxPool3dWithIndicesAllOnesModule_basic",
+    "MaxPool3dWithIndicesCeilModeTrueModule_basic",
+    "MaxPool3dWithIndicesFullSizeKernelModule_basic",
+    "MaxPool3dWithIndicesModule_basic",
+    "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
+    "MaxPool3dWithIndicesNonDefaultPaddingModule_basic",
+    "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
+    "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool3dWithIndicesStaticModule_basic",
     "MeanDimAllReduceKeepdimModule_basic",
     "MeanDimAllReduceModule_basic",
     "MeanDimDtypeModule_basic",
@@ -1046,6 +1046,10 @@ def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[in
 def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
+def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
+    maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
+    return maxpool3d, indices
+
 def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
     assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
     assert (len(output_size) == 3), "output_size must have 3 elements"
@@ -3118,6 +3122,11 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2]))
+def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> Tuple[int, int]:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype, torch.int64
+
 def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
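The new shape function reuses `_max_pool3d` for both results, and the dtype function pins the indices to `torch.int64`. A worked example of what these abstract-interpretation functions promise, using the same `(2, 3, 5, 7, 8)` shape the `@check_dtype_function` decorator above exercises (stride defaults to the kernel size):

    import torch

    x = torch.zeros(2, 3, 5, 7, 8)
    vals, idx = torch.ops.aten.max_pool3d_with_indices(x, kernel_size=[2, 2, 2])
    # Per spatial dim: floor((in + 2*pad - dilation*(k-1) - 1) / stride) + 1
    #   D: (5 - 2)//2 + 1 = 2,  H: (7 - 2)//2 + 1 = 3,  W: (8 - 2)//2 + 1 = 4
    assert vals.shape == idx.shape == (2, 3, 2, 3, 4)
    # Values follow the input dtype; indices are always int64.
    assert vals.dtype == torch.float32 and idx.dtype == torch.int64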
@@ -956,6 +956,252 @@ def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class MaxPool3dWithIndicesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x,
+            kernel_size=[2, 2, 2],
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            dilation=[1, 1, 1],
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesModule())
+def MaxPool3dWithIndicesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 8, 8, 8, low=0.5, high=1.0))
+
+
+class MaxPool3dWithIndicesFullSizeKernelModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 4, 4], stride=1, padding=0, dilation=1
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesFullSizeKernelModule())
+def MaxPool3dWithIndicesFullSizeKernelModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 4, 4, low=0.5, high=1.0))
+
+
+class MaxPool3dWithIndicesNonDefaultPaddingModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1
+        )
+
+
+@register_test_case(
+    module_factory=lambda: MaxPool3dWithIndicesNonDefaultPaddingModule()
+)
+def MaxPool3dWithIndicesNonDefaultPaddingModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 16, 16, 16, low=-1.5, high=1.0))
+
+
+class MaxPool3dWithIndicesNonDefaultStrideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 4, 4], stride=[1, 2, 1], padding=0, dilation=1
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesNonDefaultStrideModule())
+def MaxPool3dWithIndicesNonDefaultStrideModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 4, 16, 80, 16, low=0.5, high=2.0))
+
+
+class MaxPool3dWithIndicesNonDefaultDilationModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 4, 4], stride=[1, 1, 1], padding=0, dilation=[2, 2, 2]
+        )
+
+
+@register_test_case(
+    module_factory=lambda: MaxPool3dWithIndicesNonDefaultDilationModule()
+)
+def MaxPool3dWithIndicesNonDefaultDilationModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 4, 16, 80, 16, low=0.5, high=2.0))
+
+
+class MaxPool3dWithIndicesNonDefaultParamsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x,
+            kernel_size=[8, 4, 8],
+            stride=[2, 2, 2],
+            padding=[1, 2, 1],
+            dilation=[2, 2, 2],
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesNonDefaultParamsModule())
+def MaxPool3dWithIndicesNonDefaultParamsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 4, 16, 80, 16, low=-0.5, high=4.0))
+
+
+class MaxPool3dWithIndicesAllNegativeValuesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1
+        )
+
+
+@register_test_case(
+    module_factory=lambda: MaxPool3dWithIndicesAllNegativeValuesModule()
+)
+def MaxPool3dWithIndicesAllNegativeValuesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 16, 16, 16, low=-4.5, high=-1.0))
+
+
+class MaxPool3dWithIndicesStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 16, 16, 16], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesStaticModule())
+def MaxPool3dWithIndicesStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 16, 16, 16, low=-4.5, high=-1.0))
+
+
+class MaxPool3dWithIndicesAllOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x,
+            kernel_size=[2, 2, 2],
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            dilation=[1, 1, 1],
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesAllOnesModule())
+def MaxPool3dWithIndicesAllOnesModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones(1, 1, 8, 8, 8))
+
+
+class MaxPool3dWithIndicesCeilModeTrueModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool3d_with_indices(
+            x,
+            kernel_size=[2, 2, 2],
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            dilation=[1, 1, 1],
+            ceil_mode=True,
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool3dWithIndicesCeilModeTrueModule())
+def MaxPool3dWithIndicesCeilModeTrueModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 8, 8, 8, low=0.5, high=1.0))
+
+
+# ==============================================================================
+
+
 class AvgPool2dFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()