mirror of https://github.com/llvm/torch-mlir

`max_unpool3d` linalg lowering (#3536)

An attempt at an `aten.max_unpool3d` to linalg lowering. There are known issues with this implementation (see the comment in the code).

pull/3576/head
parent f1c74e1431
commit 8bd1b9751f

@@ -11,6 +11,7 @@
#include "PopulatePatterns.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||

@@ -583,6 +584,258 @@ public:
};
} // namespace

namespace {
// Max unpooling operation: takes the result of a max_pooling op together with
// the indices and tries to reconstruct the original pooling input by filling
// in values either from self or with zero.
// The upstream CPU implementation uses a parallel loop over the indices array
// to fill the output tensor, but such an approach requires random-access
// writes, which are tricky to represent in linalg.
// Instead we use a different method: we map each input/index value to
// multiple output values via affine maps in linalg.generic; then, inside the
// body of the generic, we compute the output index and compare it with the
// expected index we got from the indices tensor, returning either the input
// value or zero.
// This method only works if the pooling windows do not overlap.
// In case of overlap (e.g. kernel_size=2, stride=1) we would need to map
// many-to-many input to output values and do a reduction. To construct such a
// mapping we need to know the original kernel size, but it is not encoded in
// the aten op. We cannot reconstruct kernel_size either, as such a
// reconstruction is ambiguous (e.g. for input_size=2, output_size=5 and
// stride=2, kernel_size can be either 2 or 3).
// Worse, without knowing the kernel size we cannot even reliably detect such
// cases, and this conversion will just return invalid values.
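// As a 1-D illustration of the non-overlapping case (an assumed example, not
// taken from the lowering itself): with stride=2, output_size=4,
// input=[5, 7] and indices=[1, 2], each output element o reads
// input[o floordiv 2] and indices[o floordiv 2]; the generic body yields the
// input value when o equals the stored index and zero otherwise, producing
// output=[0, 5, 7, 0].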
class ConvertAtenMaxUnpool3dOp final
    : public OpConversionPattern<AtenMaxUnpool3dOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op->getLoc();
    const TypeConverter *typeConverter = getTypeConverter();
    Value self = adaptor.getSelf();
    auto selfType = cast<RankedTensorType>(self.getType());

    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
    if (ShapedType::isDynamicShape(inputSize))
      return rewriter.notifyMatchFailure(op,
                                         "input type must be of static shape");

    Value indices = adaptor.getIndices();
    auto indicesType = cast<RankedTensorType>(indices.getType());
    if (inputSize != indicesType.getShape().take_back(3))
      return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");

    auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(op, "invalid result type");

    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
    if (ShapedType::isDynamicShape(inferredOutSize))
      return rewriter.notifyMatchFailure(op,
                                         "output type must be of static shape");

    {
      SmallVector<int64_t> output;
      if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(output)))
        return rewriter.notifyMatchFailure(op,
                                           "only support constant int output");

      if (inferredOutSize != ArrayRef(output))
        return rewriter.notifyMatchFailure(op, "Invalid output size");
    }
    SmallVector<int64_t> stride;
    SmallVector<int64_t> padding;

    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int strides");

    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int padding");

    // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
    // (padding.size() == 6).
    if (stride.size() != 3 || padding.size() != 3)
      return rewriter.notifyMatchFailure(
          op, "stride and padding must be of size 3");

    int64_t outRank = resType.getRank();
    int64_t NC = outRank - 3;

    for (auto &&[inDim, outDim, str, pad] :
         llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
      // Kernel size computation is ambiguous; this formula returns the
      // biggest possible kernel size. As there is no way to know the actual
      // kernel size, we have to treat it conservatively and always bail if
      // the kernel size is potentially bigger than the stride.
      int64_t kernelSize = outDim - (inDim - 1) * str + 2 * pad;
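      // E.g. for inDim=2, outDim=5, str=2, pad=0 this yields kernelSize=3;
      // a kernel of size 2 would have produced the same shapes, which is
      // exactly the ambiguity described above, so we assume the largest.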
      if (kernelSize > str)
        return rewriter.notifyMatchFailure(
            op, "potential pooling windows overlapping is detected, this case "
                "is not supported yet");
    }

    Type indexType = rewriter.getIndexType();
    SmallVector<Value> outSizePadded;
    for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
      if (int64_t(i) < NC) {
        outSizePadded.emplace_back(
            rewriter.create<tensor::DimOp>(loc, self, i));
        continue;
      }
      int64_t pad = padding[i - NC];

      outSizePadded.emplace_back(
          rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
    }

    auto ceilDiv = [](int64_t v1, int64_t v2) -> int64_t {
      return (v1 + v2 - 1) / v2;
    };

    // If the input tensor size is not divisible by the stride
    // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2),
    // pad the self and indices tensors to avoid out-of-bounds accesses.
    SmallVector<int64_t> expectedInputShape =
        llvm::to_vector(resType.getShape().drop_back(3));
    for (auto &&[str, pad, resSize] :
         llvm::zip_equal(stride, padding, inferredOutSize))
      expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
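    // E.g. for resSize=5, str=2, pad=0 the expected extent is
    // ceilDiv(5, 2) = 3, while max_pool3d with kernel_size=2, stride=2
    // produces only 2 elements, so self/indices get padded by one element.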

    if (expectedInputShape != selfType.getShape()) {
      // TODO: this is probably expensive, and it may be possible to solve by
      // cleverly constructing affine maps for the next linalg.generic op,
      // but I'm not smart enough to figure this out.

      SmallVector<int64_t> low(outRank, 0);
      SmallVector<int64_t> high(NC, 0);
      for (auto &&[inpSize, outSize] : llvm::zip_equal(
               inputSize, ArrayRef(expectedInputShape).take_back(3))) {
        high.emplace_back(outSize - inpSize);
      }

      // Pad the indices tensor with a value which cannot appear in real data
      // (-1) so it will never match. In this case we can pad self with any
      // value, as it will never affect the output.
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(selfType.getElementType()));
      Value invalidIdx = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
      self =
          torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
      indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low,
                                                 high, invalidIdx);
    }

    Value init = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());

    SmallVector<AffineExpr> inputExprs;
    SmallVector<AffineExpr> outputExprs;
    for (auto i : llvm::seq<int64_t>(0, outRank)) {
      AffineExpr dim = rewriter.getAffineDimExpr(i);
      if (i < NC) {
        inputExprs.emplace_back(dim);
      } else {
        int64_t j = i - NC;
        inputExprs.emplace_back(dim.floorDiv(stride[j]));
      }
      outputExprs.emplace_back(dim);
    }

    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
        {inputExprs, inputExprs, outputExprs}, rewriter.getContext());

    SmallVector<utils::IteratorType> iteratorTypes(
        outRank, utils::IteratorType::parallel);

    auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
      // The next linalg.generic uses an identity mapping for the unpooled
      // tensor; compute the linear index for the output element, which we
      // will then compare with the values that came from the indices tensor.
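      // E.g. for an output of spatial shape [D, H, W], the element at
      // (d, h, w) gets linear index (d * H + h) * W + w, matching how
      // max_pool3d records its indices.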
      Value ret;
      for (auto i : llvm::seq<int64_t>(NC, outRank)) {
        Value idx = b.create<linalg::IndexOp>(loc, i);
        // If the pool input was padded, adjust the indices so they start at 0
        // in the non-padded area. Indices outside the non-padded area will
        // make no sense, but it doesn't matter, as we will cut off the padded
        // area later with extract_slice.
        int64_t pad = padding[i - NC];
        if (pad != 0) {
          Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
          idx = b.create<arith::SubIOp>(loc, idx, padVal);
        }

        if (!ret) {
          ret = idx;
        } else {
          Value size =
              b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
          ret = b.create<arith::MulIOp>(loc, ret, size);
          ret = b.create<arith::AddIOp>(loc, ret, idx);
        }
      }
      return ret;
    };

    auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
      // Compute the current output linear index and compare it with the value
      // from the indices arg.
      Value input = args[0];
      Value zero = b.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(input.getType()));
      Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
      Value currentIndex = computeIndex(b, loc);
      Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
                                          currentIndex);
      Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
      b.create<linalg::YieldOp>(loc, out);
    };

    Value result =
        rewriter
            .create<linalg::GenericOp>(loc,
                                       /*resultTensorTypes=*/init.getType(),
                                       /*inputs=*/ValueRange({self, indices}),
                                       /*outputs=*/init,
                                       /*indexingMaps=*/indexingMaps,
                                       /*iteratorTypes=*/iteratorTypes, builder)
            .getResult(0);

    if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
      // The MaxPool input was padded; unpad it by taking a slice.
      SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
      for (int64_t pad : padding)
        offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
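      // E.g. with padding (0, 0, 1) the generic op wrote into a tensor whose
      // last spatial dim is one element larger than the result; the slice
      // below starts at offset (0, 0, 1) and takes the original extents.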

      SmallVector<OpFoldResult> sizeVals;
      for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
        if (!ShapedType::isDynamic(dim)) {
          sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
          continue;
        }

        sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
      }
      SmallVector<OpFoldResult> stridesVals(outRank,
                                            rewriter.getI64IntegerAttr(1));
      result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
                                                       sizeVals, stridesVals);
    }

    if (result.getType() != resType)
      result = rewriter.create<tensor::CastOp>(loc, resType, result);

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
template <typename OpTy, typename PoolingOpTy, int Dim>
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
@@ -1275,6 +1528,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
  patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);

  target.addIllegalOp<AtenMaxUnpool3dOp>();
  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);

  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
  patterns
      .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
@@ -8159,6 +8159,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
"    return %arg1 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
"    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
"    %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
"    %none = torch.constant.none\n"
"    %str_1 = torch.constant.str \"AssertionError: Input must be of rank 4 or 5\"\n"
"    %true = torch.constant.bool true\n"
"    %int5 = torch.constant.int 5\n"
"    %int4 = torch.constant.int 4\n"
"    %int3 = torch.constant.int 3\n"
"    %int0 = torch.constant.int 0\n"
"    %int1 = torch.constant.int 1\n"
"    %int2 = torch.constant.int 2\n"
"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
"    %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n"
"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
"      torch.prim.If.yield %true : !torch.bool\n"
"    } else {\n"
"      %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
"      %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n"
"      torch.prim.If.yield %12 : !torch.bool\n"
"    }\n"
"    torch.prim.If %2 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
"    %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n"
"    torch.prim.If %4 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
"    %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
"    %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
"    torch.prim.If %7 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
"    %9 = torch.aten.eq.int %8, %int5 : !torch.int, !torch.int -> !torch.bool\n"
"    %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %15 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %16 = torch.prim.ListConstruct %11, %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
"      torch.prim.If.yield %16 : !torch.list<int>\n"
"    } else {\n"
"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %14 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
"      %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
"      torch.prim.If.yield %15 : !torch.list<int>\n"
"    }\n"
"    return %10 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
"    return %arg2 : !torch.list<int>\n"
"  }\n"
@@ -11687,6 +11751,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
"    return %1 : !torch.tuple<int, int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n"
"    %int4 = torch.constant.int 4\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -2471,6 +2471,8 @@ ONNX_XFAIL_SET = {
    "MaxPool3dLargeDatadModule_basic",
    "MaxPool3dModuleRandomSimple_basic",
    "MaxPool3dModule_basic",
    "MaxUnpool3dModule_basic",
    "MaxUnpool3dModulePad0_basic",
    "MeanDimEmptyDimModule_basic",
    "Mlp1LayerModule_basic",
    "Mlp2LayerModuleNoBias_basic",
@@ -1043,6 +1043,15 @@ def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[in
def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
    return self

def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
    assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5"
    assert (len(output_size) == 3), "output_size must have 3 elements"
    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
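    # E.g. a rank-5 self of shape [2, 2, 2, 2, 4] with output_size=[4, 5, 6]
    # yields [2, 2, 4, 5, 6]; batch and channel dims pass through unchanged.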
    if len(self) == 5:
        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
    else:
        return [self[0], output_size[0], output_size[1], output_size[2]]

def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
    return input_size


@@ -3054,6 +3063,10 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
    self_rank, self_dtype = self_rank_dtype
    return self_dtype, torch.int64

def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]:
    self_rank, self_dtype = self_rank_dtype
@@ -1698,3 +1698,61 @@ class AdaptiveMaxPool3dStaticWithIndices(torch.nn.Module):
@register_test_case(module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices())
def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 10, 16, 17))


# ==============================================================================


class MaxUnpool3dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, 2, 2, 4], torch.float32, True),
            ([-1, -1, 2, 2, 4], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))


@register_test_case(module_factory=lambda: MaxUnpool3dModule())
def MaxUnpool3dModule_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 5, 6)
    pool = torch.nn.MaxPool3d(
        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
    )
    output, indices = pool(input)
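    # With this configuration the pooled output has spatial shape (2, 2, 4),
    # matching the annotations above, so unpooling reconstructs the original
    # (4, 5, 6) extent.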

    module.forward(output, indices)


# We have a special case for all-zeros padding, so test it too.
class MaxUnpool3dModulePad0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, 2, 2, 3], torch.float32, True),
            ([-1, -1, 2, 2, 3], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))


@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0())
def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 5, 6)
    pool = torch.nn.MaxPool3d(
        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True
    )
    output, indices = pool(input)

    module.forward(output, indices)