diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 1c3de1107..ae1717bc2 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -11,6 +11,7 @@
 #include "PopulatePatterns.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -583,6 +584,258 @@ public:
 };
 } // namespace
 
+namespace {
+// Max unpooling operation: it takes the result of a max_pooling op and the
+// corresponding indices and tries to reconstruct the original pooling input,
+// filling each output value with either a value from self or zero.
+// The upstream CPU implementation uses a parallel loop over the indices array
+// to fill the output tensor, but such an approach requires random-access
+// writes, which are tricky to represent in linalg.
+// Instead we use a different method: we map each input/index value to multiple
+// output values via affine maps in linalg.generic; then, inside the body of
+// the generic, we compute the output index and compare it with the expected
+// index we got from the input, yielding either the input value or zero.
+// This method only works if the pooling windows do not overlap.
+// In case of overlap (e.g. kernel_size=2, stride=1) we would need a
+// many-to-many mapping between input and output values plus a reduction. To
+// construct such a mapping we would need to know the original kernel size, but
+// it isn't encoded in the aten op. We cannot reconstruct kernel_size either,
+// as such a reconstruction is ambiguous (e.g. for input_size=2, output_size=5
+// and stride=2, kernel_size can be either 2 or 3). Worse, without knowing the
+// kernel size we cannot even reliably detect such cases, and this conversion
+// would silently return invalid values.
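+//
+// For example (1-D, for illustration): with stride=2, padding=0 and
+// output_size=4, output position p reads input[p / 2] and keeps it only when
+// indices[p / 2] == p, so input=[6, 8] with indices=[1, 2] reconstructs to
+// output=[0, 6, 8, 0].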
+class ConvertAtenMaxUnpool3dOp final
+    : public OpConversionPattern<AtenMaxUnpool3dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    const TypeConverter *typeConverter = getTypeConverter();
+    Value self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
+    if (ShapedType::isDynamicShape(inputSize))
+      return rewriter.notifyMatchFailure(op,
+                                         "input type must be of static shape");
+
+    Value indices = adaptor.getIndices();
+    auto indicesType = cast<RankedTensorType>(indices.getType());
+    if (inputSize != indicesType.getShape().take_back(3))
+      return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
+
+    auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(op, "invalid result type");
+
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
+    if (ShapedType::isDynamicShape(inferredOutSize))
+      return rewriter.notifyMatchFailure(
+          op, "output type must be of static shape");
+
+    {
+      SmallVector<int64_t> output;
+      if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(output)))
+        return rewriter.notifyMatchFailure(op,
+                                           "only support constant int output");
+
+      if (inferredOutSize != ArrayRef(output))
+        return rewriter.notifyMatchFailure(op, "invalid output size");
+    }
+    SmallVector<int64_t> stride;
+    SmallVector<int64_t> padding;
+
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only support constant int strides");
+
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only support constant int padding");
+
+    // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
+    // (padding.size() == 6).
+    if (stride.size() != 3 || padding.size() != 3)
+      return rewriter.notifyMatchFailure(
+          op, "stride and padding must be of size 3");
+
+    int64_t outRank = resType.getRank();
+    int64_t NC = outRank - 3;
+
+    for (auto &&[inDim, outDim, str, pad] :
+         llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
+      // Kernel size computation is ambiguous; this formula returns the
+      // biggest possible kernel size. As there is no way to know the actual
+      // kernel size, we have to be conservative and always bail if the kernel
+      // size is potentially bigger than the stride.
+      int64_t kernelSize = outDim - (inDim - 1) * str + 2 * pad;
+      if (kernelSize > str)
+        return rewriter.notifyMatchFailure(
+            op, "potential pooling windows overlapping is detected, this case "
+                "is not supported yet");
+    }
+
+    Type indexType = rewriter.getIndexType();
+    SmallVector<Value> outSizePadded;
+    for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
+      if (int64_t(i) < NC) {
+        outSizePadded.emplace_back(
+            rewriter.create<tensor::DimOp>(loc, self, i));
+        continue;
+      }
+      int64_t pad = padding[i - NC];
+
+      outSizePadded.emplace_back(
+          rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
+    }
+
+    auto ceilDiv = [](int64_t v1, int64_t v2) -> int64_t {
+      return (v1 + v2 - 1) / v2;
+    };
+
+    // If the input tensor size is not divisible by the stride
+    // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2),
+    // pad the self and indices tensors to avoid out-of-bounds accesses.
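+    // For example: inferredOutSize=5, stride=2, padding=0 gives an expected
+    // input extent of ceilDiv(5, 2) = 3, while a MaxPool with kernel_size=2
+    // produced an extent of only 2, so self and indices are padded by one
+    // trailing element along that dimension.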
+    SmallVector<int64_t> expectedInputShape =
+        llvm::to_vector(resType.getShape().drop_back(3));
+    for (auto &&[str, pad, resSize] :
+         llvm::zip_equal(stride, padding, inferredOutSize))
+      expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
+
+    if (expectedInputShape != selfType.getShape()) {
+      // TODO: this is probably expensive, and it may be possible to solve by
+      // cleverly constructing affine maps for the next linalg.generic op,
+      // but I'm not smart enough to figure this out.
+
+      SmallVector<int64_t> low(outRank, 0);
+      SmallVector<int64_t> high(NC, 0);
+      for (auto &&[inpSize, outSize] : llvm::zip_equal(
+               inputSize, ArrayRef(expectedInputShape).take_back(3))) {
+        high.emplace_back(outSize - inpSize);
+      }
+
+      // Pad the indices tensor with a value which cannot appear in real data
+      // (-1) so it will never match. In this case we can pad self with any
+      // value, as it will never affect the output.
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(selfType.getElementType()));
+      Value invalidIdx = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
+      self =
+          torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
+      indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low,
+                                                 high, invalidIdx);
+    }
+
+    Value init = rewriter.create<tensor::EmptyOp>(
+        loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());
+
+    SmallVector<AffineExpr> inputExprs;
+    SmallVector<AffineExpr> outputExprs;
+    for (auto i : llvm::seq<int64_t>(0, outRank)) {
+      AffineExpr dim = rewriter.getAffineDimExpr(i);
+      if (i < NC) {
+        inputExprs.emplace_back(dim);
+      } else {
+        int64_t j = i - NC;
+        inputExprs.emplace_back(dim.floorDiv(stride[j]));
+      }
+      outputExprs.emplace_back(dim);
+    }
+
+    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
+        {inputExprs, inputExprs, outputExprs}, rewriter.getContext());
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        outRank, utils::IteratorType::parallel);
+
+    auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
+      // The following linalg.generic uses an identity mapping for the
+      // unpooled tensor; compute the linear index of the output element,
+      // which we will then compare with the values coming from the indices
+      // tensor.
+      Value ret;
+      for (auto i : llvm::seq<int64_t>(NC, outRank)) {
+        Value idx = b.create<linalg::IndexOp>(loc, i);
+        // If the pooling input was padded, adjust the indices so they start
+        // at 0 in the non-padded area. Indices outside the non-padded area
+        // make no sense, but it doesn't matter as we will cut the padded
+        // area off later via extract_slice.
+        int64_t pad = padding[i - NC];
+        if (pad != 0) {
+          Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
+          idx = b.create<arith::SubIOp>(loc, idx, padVal);
+        }
+
+        if (!ret) {
+          ret = idx;
+        } else {
+          Value size =
+              b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
+          ret = b.create<arith::MulIOp>(loc, ret, size);
+          ret = b.create<arith::AddIOp>(loc, ret, idx);
+        }
+      }
+      return ret;
+    };
+
+    auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
+      // Compute the current output linear index and compare it with the
+      // value from the indices arg.
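+      // For an output of shape [N, C, D, H, W], the linear index of element
+      // (d, h, w) is (d * H + h) * W + w, matching the layout of the
+      // flattened per-channel indices produced by MaxPool3d with
+      // return_indices=True.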
+      Value input = args[0];
+      Value zero = b.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(input.getType()));
+      Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
+      Value currentIndex = computeIndex(b, loc);
+      Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
+                                          currentIndex);
+      Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
+      b.create<linalg::YieldOp>(loc, out);
+    };
+
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(loc,
+                                       /*resultTensorTypes=*/init.getType(),
+                                       /*inputs=*/ValueRange({self, indices}),
+                                       /*outputs=*/init,
+                                       /*indexingMaps=*/indexingMaps,
+                                       /*iteratorTypes=*/iteratorTypes, builder)
+            .getResult(0);
+
+    if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
+      // The MaxPool input was padded; unpad it by taking a slice.
+      SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
+      for (int64_t pad : padding)
+        offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
+
+      SmallVector<OpFoldResult> sizeVals;
+      for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
+        if (!ShapedType::isDynamic(dim)) {
+          sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
+          continue;
+        }
+
+        sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
+      }
+      SmallVector<OpFoldResult> stridesVals(outRank,
+                                            rewriter.getI64IntegerAttr(1));
+      result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
+                                                       sizeVals, stridesVals);
+    }
+
+    if (result.getType() != resType)
+      result = rewriter.create<tensor::CastOp>(loc, resType, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename OpTy, typename PoolingOpTy, int Dim>
 class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
@@ -1275,6 +1528,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
+
+  target.addIllegalOp<AtenMaxUnpool3dOp>();
+  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp>>(
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 24f8648cc..190469c3a 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8159,6 +8159,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
+"    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
+"    %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: Input must be of rank 4 or 5\"\n"
+"    %true = torch.constant.bool true\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int4 = torch.constant.int 4\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %12 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %9 = torch.aten.eq.int %8, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"    %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %15 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %16 = torch.prim.ListConstruct %11, %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %16 : !torch.list<int>\n"
+"    } else {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %15 : !torch.list<int>\n"
+"    }\n"
+"    return %10 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
@@ -11687,6 +11751,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %1 : !torch.tuple<int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 85258d6f8..e2ad3310e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2471,6 +2471,8 @@ ONNX_XFAIL_SET = {
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
     "MaxPool3dModule_basic",
+    "MaxUnpool3dModule_basic",
+    "MaxUnpool3dModulePad0_basic",
     "MeanDimEmptyDimModule_basic",
     "Mlp1LayerModule_basic",
     "Mlp2LayerModuleNoBias_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6e44bc127..cabe40e80 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1043,6 +1043,15 @@ def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[in
 def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
+def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+    assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5"
+    assert (len(output_size) == 3), "output_size must have 3 elements"
+    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
+    if len(self) == 5:
+        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
+    else:
+        return [self[0], output_size[0], output_size[1], output_size[2]]
+
 def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
 
@@ -3054,6 +3063,10 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
+def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
 def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 1de40096c..ae26a7cef 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1698,3 +1698,61 @@ class AdaptiveMaxPool3dStaticWithIndices(torch.nn.Module):
 @register_test_case(module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices())
 def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 512, 10, 16, 17))
+
+
+# ==============================================================================
+
+
+class MaxUnpool3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, 2, 2, 4], torch.float32, True),
+            ([-1, -1, 2, 2, 4], torch.int64, True),
+        ]
+    )
+    def forward(self, x, indices):
+        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
+
+
+@register_test_case(module_factory=lambda: MaxUnpool3dModule())
+def MaxUnpool3dModule_basic(module, tu: TestUtils):
+    input = tu.rand(2, 2, 4, 5, 6)
+    pool = torch.nn.MaxPool3d(
+        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
+    )
+    output, indices = pool(input)
+
+    module.forward(output, indices)
+
+
+# We have a special case for all-zeros padding, so test it too.
+class MaxUnpool3dModulePad0(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, 2, 2, 3], torch.float32, True),
+            ([-1, -1, 2, 2, 3], torch.int64, True),
+        ]
+    )
+    def forward(self, x, indices):
+        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
+
+
+@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0())
+def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
+    input = tu.rand(2, 2, 4, 5, 6)
+    pool = torch.nn.MaxPool3d(
+        kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True
+    )
+    output, indices = pool(input)
+
+    module.forward(output, indices)
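
# ==============================================================================
# For reference, a minimal NumPy sketch of the unpooling semantics the
# conversion above implements (an illustration under the patch's
# non-overlapping-windows assumption, not code from the patch itself): each
# output position keeps the input value whose stored index matches its own
# position, mirroring the gather-style linalg.generic rather than the scatter
# used by the upstream CPU implementation.
import numpy as np

def max_unpool1d_ref(x, indices, out_size, stride):
    # x, indices: 1-D arrays of equal length; requires kernel_size <= stride
    # (non-overlapping windows), the same restriction the lowering enforces.
    out = np.zeros(out_size, dtype=x.dtype)
    for p in range(out_size):
        i = p // stride  # the input element mapped to this output position
        if i < len(x) and indices[i] == p:
            out[p] = x[i]  # stored index matches: keep the value
    return out

# Example: max_unpool1d_ref(np.array([6.0, 8.0]), np.array([1, 3]), 4, 2)
# returns array([0., 6., 0., 8.]).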