`max_unpool3d` linalg lowering (#3536)

An attempt at lowering `aten.max_unpool3d` to linalg.
There are known issues with this implementation (see comment in code).
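For context, the op being lowered here can be exercised from PyTorch roughly as in the new e2e tests further down; a minimal sketch (shapes and arguments taken from those tests):

import torch

pool = torch.nn.MaxPool3d(
    kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
)
x = torch.rand(2, 2, 4, 5, 6)
pooled, indices = pool(x)
# aten.max_unpool3d scatters the pooled maxima back into a (4, 5, 6) volume,
# filling every non-maximum position with zero.
unpooled = torch.ops.aten.max_unpool3d(pooled, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))

The lowering below avoids the scatter by broadcasting each pooled value over its pooling window and masking it against the stored indices.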
Ivan Butygin 2024-07-30 20:59:17 +03:00 committed by GitHub
parent f1c74e1431
commit 8bd1b9751f
5 changed files with 398 additions and 0 deletions


@@ -11,6 +11,7 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -583,6 +584,258 @@ public:
};
} // namespace
namespace {
// Max unpooling operation: takes the result of a max_pooling op together with
// its indices and tries to reconstruct the original pooling input by filling
// each output value with either the corresponding value from self or zero.
// The upstream CPU implementation uses a parallel loop over the indices array
// to fill the output tensor, but such an approach requires random-access
// writes, which are tricky to represent in linalg.
// Instead we use a different method: we map each input/index value to multiple
// output values via affine maps in linalg.generic, and then, inside the body
// of the generic, we compute the output index and compare it with the expected
// index we got from the indices tensor, yielding either the input value or
// zero.
// This method only works if the pooling windows do not overlap.
// In case of overlap (e.g. kernel_size=2, stride=1) we would need a
// many-to-many mapping from input to output values and a reduction. To
// construct such a mapping we need to know the original kernel size, but it is
// not encoded in the aten op. We cannot reconstruct kernel_size either, as such
// a reconstruction is ambiguous (e.g. for input_size=2, output_size=5 and
// stride=2, kernel_size can be either 2 or 3).
// What's worse, without knowing the kernel size we cannot even reliably detect
// such cases, so this conversion will just return invalid values for them.
class ConvertAtenMaxUnpool3dOp final
: public OpConversionPattern<AtenMaxUnpool3dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
const TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());
ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
if (ShapedType::isDynamicShape(inputSize))
return rewriter.notifyMatchFailure(op,
"input type must be of static shape");
Value indices = adaptor.getIndices();
auto indicesType = cast<RankedTensorType>(indices.getType());
if (inputSize != indicesType.getShape().take_back(3))
return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
if (!resType)
return rewriter.notifyMatchFailure(op, "invalid result type");
ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
if (ShapedType::isDynamicShape(inferredOutSize))
return rewriter.notifyMatchFailure(op,
"output type must be of static shape");
{
SmallVector<int64_t> output;
if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(output)))
return rewriter.notifyMatchFailure(op,
"only support constant int output");
if (inferredOutSize != ArrayRef(output))
return rewriter.notifyMatchFailure(op, "Invalid output size");
}
SmallVector<int64_t> stride;
SmallVector<int64_t> padding;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))
return rewriter.notifyMatchFailure(op,
"only support constant int padding");
// TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
// (padding.size() == 6).
if (stride.size() != 3 || padding.size() != 3)
return rewriter.notifyMatchFailure(
op, "stride and padding must be of size 3");
int64_t outRank = resType.getRank();
int64_t NC = outRank - 3;
for (auto &&[inDim, outDim, str, pad] :
llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
// Kernel size computation is ambiguous; this formula returns the biggest
// possible kernel size. As there is no way to know the actual kernel size,
// we have to be conservative and always bail if the kernel size is
// potentially bigger than the stride.
int64_t kernelSize = outDim - (inDim - 1) * str + 2 * pad;
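// E.g. inDim=2, outDim=5, str=2, pad=0 gives kernelSize = 3 > str, so the
// ambiguous case from the comment at the top of this pattern is rejected by
// the check below.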
if (kernelSize > str)
return rewriter.notifyMatchFailure(
op, "potential pooling windows overlapping is detected, this case "
"is not supported yet");
}
Type indexType = rewriter.getIndexType();
SmallVector<Value> outSizePadded;
for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
if (int64_t(i) < NC) {
outSizePadded.emplace_back(
rewriter.create<tensor::DimOp>(loc, self, i));
continue;
}
int64_t pad = padding[i - NC];
outSizePadded.emplace_back(
rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
}
auto ceilDiv = [](int64_t v1, int64_t v2) -> int64_t {
return (v1 + v2 - 1) / v2;
};
// In case the input tensor size is not divisible by the stride
// (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2),
// pad the self and indices tensors to avoid out-of-bounds accesses.
SmallVector<int64_t> expectedInputShape =
llvm::to_vector(resType.getShape().drop_back(3));
for (auto &&[str, pad, resSize] :
llvm::zip_equal(stride, padding, inferredOutSize))
expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
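// Continuing the example from the comment above: ceilDiv(5, 2) + 0 = 3, while
// self/indices only have extent 2 along that dimension, so they get padded
// below.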
if (expectedInputShape != selfType.getShape()) {
// TODO: this is probably expensive, and it may be possible to solve by
// cleverly constructing affine maps for the next linalg.generic op,
// but I'm not smart enough to figure this out.
SmallVector<int64_t> low(outRank, 0);
SmallVector<int64_t> high(NC, 0);
for (auto &&[inpSize, outSize] : llvm::zip_equal(
inputSize, ArrayRef(expectedInputShape).take_back(3))) {
high.emplace_back(outSize - inpSize);
}
// Pad the indices tensor with a value which cannot appear in real data
// (-1) so it will never match. In this case we can pad self with any
// value, as it will never affect the output.
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(selfType.getElementType()));
Value invalidIdx = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
self =
torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low,
high, invalidIdx);
}
Value init = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());
SmallVector<AffineExpr> inputExprs;
SmallVector<AffineExpr> outputExprs;
for (auto i : llvm::seq<int64_t>(0, outRank)) {
AffineExpr dim = rewriter.getAffineDimExpr(i);
if (i < NC) {
inputExprs.emplace_back(dim);
} else {
int64_t j = i - NC;
inputExprs.emplace_back(dim.floorDiv(stride[j]));
}
outputExprs.emplace_back(dim);
}
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
{inputExprs, inputExprs, outputExprs}, rewriter.getContext());
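// For example, for a rank-5 result and stride=(2, 3, 2) the self/indices map
// is (d0, d1, d2, d3, d4) -> (d0, d1, d2 floordiv 2, d3 floordiv 3,
// d4 floordiv 2) and the output map is the identity, so each input element is
// broadcast over its pooling window.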
SmallVector<utils::IteratorType> iteratorTypes(
outRank, utils::IteratorType::parallel);
auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
// The next linalg.generic uses an identity mapping for the unpooled tensor;
// compute the linear index for the output element, which we will then
// compare with the values coming from the indices tensor.
Value ret;
for (auto i : llvm::seq<int64_t>(NC, outRank)) {
Value idx = b.create<linalg::IndexOp>(loc, i);
// If the pool input was padded, adjust the indices so they start at 0 in
// the non-padded area. Indices outside the non-padded area make no sense,
// but it doesn't matter as we cut off the padded area later with
// extract_slice.
int64_t pad = padding[i - NC];
if (pad != 0) {
Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
idx = b.create<arith::SubIOp>(loc, idx, padVal);
}
if (!ret) {
ret = idx;
} else {
Value size =
b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
ret = b.create<arith::MulIOp>(loc, ret, size);
ret = b.create<arith::AddIOp>(loc, ret, idx);
}
}
return ret;
};
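// For a rank-5 result the value computed above is
// ((d - pad_d) * H + (h - pad_h)) * W + (w - pad_w), i.e. the flat index into
// the D*H*W volume that MaxPool3d records for each maximum.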
auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
// Compute the current output linear index and compare it with the value
// from the indices arg.
Value input = args[0];
Value zero = b.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(input.getType()));
Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
Value currentIndex = computeIndex(b, loc);
Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
currentIndex);
Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
b.create<linalg::YieldOp>(loc, out);
};
Value result =
rewriter
.create<linalg::GenericOp>(loc,
/*resultTensorTypes=*/init.getType(),
/*inputs=*/ValueRange({self, indices}),
/*outputs=*/init,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes, builder)
.getResult(0);
if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
// The MaxPool input was padded; unpad the result by taking a slice.
SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
for (int64_t pad : padding)
offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
SmallVector<OpFoldResult> sizeVals;
for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
if (!ShapedType::isDynamic(dim)) {
sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
continue;
}
sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
}
SmallVector<OpFoldResult> stridesVals(outRank,
rewriter.getI64IntegerAttr(1));
result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
sizeVals, stridesVals);
}
if (result.getType() != resType)
result = rewriter.create<tensor::CastOp>(loc, resType, result);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
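As a rough model of the trick used by the pattern above, here is a 1-D NumPy sketch for the supported case (non-overlapping windows, no padding); the function name is illustrative only and not part of the patch:

import numpy as np

def unpool1d_reference(pooled, indices, out_size, stride):
    # Assumes non-overlapping windows and out_size <= len(pooled) * stride.
    # Each output position j reads the pooled value at j // stride (the
    # floordiv affine map in the linalg.generic) and keeps it only when the
    # recorded index equals j; every other position stays zero.
    out = np.zeros(out_size, dtype=pooled.dtype)
    for j in range(out_size):
        if indices[j // stride] == j:
            out[j] = pooled[j // stride]
    return out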
namespace {
template <typename OpTy, typename PoolingOpTy, int Dim>
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
@@ -1275,6 +1528,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
target.addIllegalOp<AtenMaxUnpool3dOp>();
patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
patterns
    .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(


@@ -8159,6 +8159,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: Input be of rank 4 or 5\"\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %9 = torch.aten.eq.int %8, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.prim.ListConstruct %11, %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %16 : !torch.list<int>\n"
" } else {\n"
" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %15 : !torch.list<int>\n"
" }\n"
" return %10 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n" " return %arg2 : !torch.list<int>\n"
" }\n" " }\n"
@@ -11687,6 +11751,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n" " %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"


@@ -2471,6 +2471,8 @@ ONNX_XFAIL_SET = {
    "MaxPool3dLargeDatadModule_basic",
    "MaxPool3dModuleRandomSimple_basic",
    "MaxPool3dModule_basic",
    "MaxUnpool3dModule_basic",
    "MaxUnpool3dModulePad0_basic",
    "MeanDimEmptyDimModule_basic",
    "Mlp1LayerModule_basic",
    "Mlp2LayerModuleNoBias_basic",


@@ -1043,6 +1043,15 @@ def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[in
def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
    return self
def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
    assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5"
    assert (len(output_size) == 3), "output_size must have 3 elements"
    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
    if len(self) == 5:
        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
    else:
        return [self[0], output_size[0], output_size[1], output_size[2]]
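# (Worked example for the shape function above: self=[2, 2, 2, 2, 3] with
# output_size=[4, 5, 6] yields [2, 2, 4, 5, 6].)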
def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
    return input_size
@@ -3054,6 +3063,10 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
    self_rank, self_dtype = self_rank_dtype
    return self_dtype, torch.int64
def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]:
    self_rank, self_dtype = self_rank_dtype


@@ -1698,3 +1698,61 @@ class AdaptiveMaxPool3dStaticWithIndices(torch.nn.Module):
@register_test_case(module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices())
def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 10, 16, 17))
# ==============================================================================
class MaxUnpool3dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, 2, 2, 4], torch.float32, True),
            ([-1, -1, 2, 2, 4], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))


@register_test_case(module_factory=lambda: MaxUnpool3dModule())
def MaxUnpool3dModule_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 5, 6)
    pool = torch.nn.MaxPool3d(
        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
    )
    output, indices = pool(input)
    module.forward(output, indices)

# We have a special case for all-zeros padding, test it too.
class MaxUnpool3dModulePad0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, 2, 2, 3], torch.float32, True),
            ([-1, -1, 2, 2, 3], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))


@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0())
def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 5, 6)
    pool = torch.nn.MaxPool3d(
        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True
    )
    output, indices = pool(input)
    module.forward(output, indices)