diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index 4a9c3da52..81b3b5484 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -45,6 +45,10 @@ Value castIntToIndex(OpBuilder &b, Location loc, Value v);
 
 Value castIndexToInt64(OpBuilder &b, Location loc, Value idx);
 
+SmallVector<Value>
+castIntVectorToIndexVector(OpBuilder &b, Location loc,
+                           SmallVectorImpl<Value> &intValues);
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
 
 SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 756516dd9..d1828c474 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -32,15 +32,21 @@ using namespace mlir::torch::Torch;
 template <typename OpTy>
 static LogicalResult
 checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
-                             bool &ceilMode,
-                             SmallVectorImpl<int64_t> &kernelSizeInts,
+                             TypeConverter *typeConverter, bool &ceilMode,
+                             SmallVectorImpl<Value> &kernelSizeIntValues,
                              SmallVectorImpl<int64_t> &strideInts,
                              SmallVectorImpl<int64_t> &paddingInts) {
   // Pattern match against the op's original operands, because otherwise we
   // will get the lowered version of the operands which is harder to pattern
   // match.
-  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
-    return rewriter.notifyMatchFailure(op, "only support kernel size ints");
+  SmallVector<Value> kernelSizeTorchInt;
+  if (!getListConstructElements(op.kernel_size(), kernelSizeTorchInt)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unimplemented: the kernel size is "
+                                       "not constructed from ListConstruct");
+  }
+  kernelSizeIntValues = getTypeConvertedValues(
+      rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt);
   if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
     return rewriter.notifyMatchFailure(op, "only support constant int strides");
   if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
@@ -58,7 +64,7 @@ template <typename OpTy>
 static LogicalResult createPoolingOp(
     Operation *op, ConversionPatternRewriter &rewriter, Value self,
     bool supportFPInputOnly, bool ceilMode,
-    SmallVectorImpl<int64_t> &kernelSizeInts,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
     SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
     SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
     SmallVectorImpl<int64_t> &outTensorShape, Value &paddedInput, Value &result) {
@@ -88,8 +94,6 @@ static LogicalResult createPoolingOp(
       getAsConstantIntValues(rewriter, loc, paddingInts);
   SmallVector<Value> dilationIntValues =
       getAsConstantIntValues(rewriter, loc, dilationInts);
-  SmallVector<Value> kernelSizeIntValues =
-      getAsConstantIntValues(rewriter, loc, kernelSizeInts);
   SmallVector<Value> strideIntValues =
       getAsConstantIntValues(rewriter, loc, strideInts);
 
@@ -108,7 +112,7 @@ static LogicalResult createPoolingOp(
   auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
   auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
   Value windowTensor = rewriter.create<linalg::InitTensorOp>(
-      loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
+      loc, castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues),
       elementType);
 
   result = rewriter
@@ -130,6 +134,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
     int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
 
     // TODO: Add support for 3D inputs.
@@ -137,14 +142,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only support 4D input");
 
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
-        dilationInts;
+    bool ceilMode;
+    SmallVector<Value> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    bool ceilMode;
     if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     Type elementType = self.getType().cast<RankedTensorType>().getElementType();
@@ -158,7 +164,7 @@ public:
     Value maxPool2d, paddedInput;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportFPInput=*/false, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
     Type newResultType = getTypeConverter()->convertType(op.getType());
@@ -200,6 +206,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op->getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
     RankedTensorType selfType = self.getType().cast<RankedTensorType>();
     Type elementType = selfType.getElementType();
@@ -213,14 +220,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only support 4D input");
 
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
-        dilationInts;
+    bool ceilMode;
+    SmallVector<Value> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    bool ceilMode;
     if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     // `maxpool2d` contains the result of maxpool2d operation over the input.
@@ -233,7 +241,7 @@ public:
     SmallVector<int64_t, 4> outTensorShape;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportFPInput=*/false, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
 
@@ -244,7 +252,7 @@ public:
         indicesRankedTensorType.getElementType(), cstMinusOne);
 
     SmallVector<Value> kernelSize =
-        getAsConstantIndexValues(rewriter, loc, kernelSizeInts);
+        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
     SmallVector<Value> padding =
         getAsConstantIndexValues(rewriter, loc, paddingInts);
     SmallVector<Value> dilation =
@@ -344,105 +352,6 @@ public:
 };
 } // namespace
 
-namespace {
-class ConvertAtenAdaptiveAvgPool2dOp
-    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenAdaptiveAvgPool2dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    MLIRContext *context = op->getContext();
-    Value input = adaptor.self(); /* in form of N*C*H*W */
-    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
-    Type elementType = inputType.getElementType();
-    if (!elementType.isa<mlir::FloatType>())
-      return op.emitError("unimplemented: non-floating point type");
-
-    auto inputRank = inputType.getRank();
-    if (inputRank != 4)
-      return rewriter.notifyMatchFailure(op, "input should be rank 4");
-
-    SmallVector<int64_t, 2> expects{1, 1};
-    // Pattern match against the op's original operands, because otherwise we
-    // will get the lowered version of the operands which is harder to pattern
-    // match.
-    if (!isConstantIntListMatching(op.output_size(), expects))
-      return rewriter.notifyMatchFailure(
-          op, "only support output_size with H and W both equal to constant 1");
-
-    Value N = getDimOp(rewriter, loc, input, 0);
-    Value C = getDimOp(rewriter, loc, input, 1);
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, C}, elementType);
-    Value c0 = rewriter.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, 0.0));
-    Value initTensor0 =
-        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-
-    SmallVector<AffineExpr, 2> ncExprs;
-    ncExprs.push_back(mlir::getAffineDimExpr(0, context));
-    ncExprs.push_back(mlir::getAffineDimExpr(1, context));
-    auto ncIndexingMap = AffineMap::get(
-        /*dimCount=*/4,
-        /*symbolCount=*/0, ncExprs, context);
-    SmallVector<AffineMap, 2> indexingMaps = {
-        rewriter.getMultiDimIdentityMap(4), // input
-        ncIndexingMap,                      // output
-    };
-    SmallVector<StringRef, 4> iteratorTypesSum{"parallel", "parallel",
-                                               "reduction", "reduction"};
-    Value sumPool2d = rewriter
-                          .create<linalg::GenericOp>(
-                              loc, initTensor0.getType(), input, initTensor0,
-                              /*indexingMaps=*/indexingMaps,
-                              /*iteratorTypes=*/iteratorTypesSum,
-                              [&](OpBuilder &b, Location loc, ValueRange args) {
-                                Value input = args[0], sum = args[1];
-                                Value result = rewriter.create<arith::AddFOp>(
-                                    loc, sum, input);
-                                b.create<linalg::YieldOp>(loc, result);
-                              })
-                          .getResult(0);
-
-    // Calculate H*W so that avg can be got from sum / (H*W)
-    Value H = getDimOp(rewriter, loc, input, 2);
-    Value W = getDimOp(rewriter, loc, input, 3);
-    auto castIndexToInt = [&](Value v) {
-      return rewriter.create<arith::IndexCastOp>(
-          loc, IntegerType::get(context, 64), v);
-    };
-    Value HtimesW = rewriter.create<arith::MulIOp>(loc, castIndexToInt(H),
-                                                   castIndexToInt(W));
-    Value HtimesWf =
-        rewriter.create<arith::SIToFPOp>(loc, elementType, HtimesW);
-
-    Value c1Index = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
-    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, C, c1Index, c1Index}, elementType);
-
-    SmallVector<AffineMap, 2> indexingMapsAvg{
-        ncIndexingMap, rewriter.getMultiDimIdentityMap(4)};
-    SmallVector<StringRef, 4> iteratorTypesAvg(4, "parallel");
-    Value avgPool2d =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, outputTensor.getType(), sumPool2d, outputTensor,
-                /*indexingMaps=*/indexingMapsAvg,
-                /*iteratorTypes=*/iteratorTypesAvg,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value avg = b.create<arith::DivFOp>(loc, args[0], HtimesWf);
-                  b.create<linalg::YieldOp>(loc, avg);
-                })
-            .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, avgPool2d);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
 public:
@@ -453,6 +362,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op->getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
 
     Type inputElementType =
@@ -461,11 +371,12 @@ public:
     Type resultElementType =
         resultType.cast<RankedTensorType>().getElementType();
 
-    SmallVector<int64_t, 2> dilationInts{1, 1};
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
     bool ceilMode;
+    SmallVector<Value> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
     if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     // TODO: Add support for count_include_pad equal to `False`.
@@ -484,13 +395,11 @@ public:
     SmallVector<int64_t, 4> outTensorShape;
     if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
             op, rewriter, self, /*supportFPInput=*/true, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             rewriter.getZeroAttr(inputElementType), outTensorShape,
             paddedInput, sumPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
 
-    SmallVector<Value> kernelSizeIntValues =
-        getAsConstantIntValues(rewriter, loc, kernelSizeInts);
     Value kHtimeskW = rewriter.create<arith::MulIOp>(
         loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
     Value divisor = op.divisor_override().getType().isa<Torch::NoneType>()
@@ -534,8 +443,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
-  target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
-  patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
   target.addIllegalOp<AtenAvgPool2dOp>();
   patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
 }
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 0c7cc96a8..9f989033e 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -143,6 +143,15 @@ Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
   return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
+SmallVector<Value>
+castIntVectorToIndexVector(OpBuilder &b, Location loc,
+                           SmallVectorImpl<Value> &intValues) {
+  SmallVector<Value> indexValues;
+  for (Value v : intValues)
+    indexValues.push_back(castIntToIndex(b, loc, v));
+  return indexValues;
+}
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
   return b.createOrFold<tensor::DimOp>(loc, v, dim);
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6418c4e9a..a136d53f2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1551,7 +1551,8 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
 // Note that this is the same decomposition as in AOTAutograd
 // https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
 namespace {
-class DecomposeAten_ReshapeAliasOp : public OpRewritePattern<Aten_ReshapeAliasOp> {
+class DecomposeAten_ReshapeAliasOp
+    : public OpRewritePattern<Aten_ReshapeAliasOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1775,6 +1776,116 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
+//
+// For the AdaptiveAvgPool2d op, when the input size is an integer multiple of
+// the output size, the kernel_size, stride, and padding are calculated as
+// follows:
+// strideH = inH // outH
+// strideW = inW // outW
+// kernelH = inH - [(outH - 1) * strideH]
+// kernelW = inW - [(outW - 1) * strideW]
+// paddingH = 0, paddingW = 0
+//
+// For the special case where the output size is one for all dimensions,
+// the kernel size is the same as the input size.
+class DecomposeAtenAdaptiveAvgPool2dOp
+    : public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.self();
+    int64_t rank = getTensorRank(input);
+    SmallVector<Value, 2> inputHW;
+    Value dimH = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 2));
+    inputHW.push_back(
+        /*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
+    Value dimW = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 1));
+    inputHW.push_back(
+        /*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
+
+    Value outputShape = op.output_size();
+    SmallVector<Value> outputShapeSizesTorchInt;
+    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+
+    // TODO: Add support for cases other than:
+    // 1.) inH == outH and inW == outW.
+    // 2.) outH == outW == 1
+    bool unitOutputSize = true;
+    for (Value outShape : outputShapeSizesTorchInt) {
+      int64_t outShapeInt;
+      if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
+        return rewriter.notifyMatchFailure(
+            op, "output size is expected to be a constant");
+      }
+      if (outShapeInt != 1) {
+        unitOutputSize = false;
+        break;
+      }
+    }
+
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+    SmallVector<Value, 2> kernelSize;
+
+    for (unsigned i = 0; i < inputHW.size(); i++) {
+      if (unitOutputSize) {
+        BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
+        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+        kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
+                                 ? inputHW[i]
+                                 : rewriter.create<Torch::ConstantIntOp>(
+                                       loc, rewriter.getI64IntegerAttr(
+                                                inputShape[rank - 2 + i])));
+      } else {
+        Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
+                                                  outputShapeSizesTorchInt[i]);
+        rewriter.create<RuntimeAssertOp>(
+            loc, cond,
+            "unimplemented: only support cases where input and output size are "
+            "equal for non-unit output size");
+
+        Value outMinusOne = rewriter.create<AtenSubIntOp>(
+            loc, outputShapeSizesTorchInt[i], constantOne);
+        kernelSize.push_back(
+            rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
+      }
+    }
+
+    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
+    // Currently, we only support cases where the input size is equal to the
+    // output size, or where the output size is unit. For the former case the
+    // stride is always equal to one, and for the latter the stride value does
+    // not matter, since the kernel size is the same as the input size.
+    // Therefore, keep the stride as one for the latter case as well, for ease
+    // of implementation.
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne, constantOne});
+    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantZero, constantZero});
+
+    rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
+        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue,
+        /*divisor_override=*/constantNone);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
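As a quick cross-check of the kernel/stride arithmetic in the comment above (not part of the patch itself), the formulas can be exercised directly against PyTorch. This is a minimal sketch assuming only a stock `torch` install; the shapes are illustrative:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 512, 14, 14)

# Integer-multiple case from the comment: inH = 14, outH = 7 gives
# strideH = 14 // 7 = 2 and kernelH = 14 - (7 - 1) * 2 = 2.
assert torch.allclose(F.adaptive_avg_pool2d(x, (7, 7)),
                      F.avg_pool2d(x, kernel_size=2, stride=2, padding=0))

# Unit output size: the kernel spans the whole spatial extent, so the
# result is the mean over H and W.
assert torch.allclose(F.adaptive_avg_pool2d(x, (1, 1)),
                      x.mean(dim=(2, 3), keepdim=True))

# Size-preserving case (the only non-unit case this decomposition accepts):
# stride = 1 and kernel = inH - (outH - 1) * 1 = 1, i.e. the identity.
assert torch.allclose(F.adaptive_avg_pool2d(x, (14, 14)), x)
```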
@@ -1910,6 +2021,8 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenPadOp>(context);
     patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
     target.addIllegalOp<AtenToDtypeLayoutOp>();
+    patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
+    target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py
index 1f3086efa..26a18b0df 100644
--- a/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -12,7 +12,72 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
-class AdaptiveAvgPool2dModule(torch.nn.Module):
+class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeStaticModule())
+def AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeStaticModule())
+def AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -27,9 +92,10 @@ class AdaptiveAvgPool2dModule(torch.nn.Module):
         return self.aap2d(x)
 
 
-@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
-def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(10, 3, 8, 9))
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
 
 
 # ==============================================================================
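The new e2e cases all feed a `1x512x7x7` tensor. Outside the e2e harness, what they assert reduces to two plain PyTorch identities; a hedged sketch, assuming only `torch`:

```python
import torch

x = torch.rand(1, 512, 7, 7)

# Non-unit output size: with a 7x7 input, AdaptiveAvgPool2d((7, 7)) is the
# identity, which is the only non-unit configuration the decomposition's
# runtime assert admits.
assert torch.allclose(torch.nn.AdaptiveAvgPool2d((7, 7))(x), x)

# Unit output size: AdaptiveAvgPool2d((1, 1)) is a mean over H and W.
assert torch.allclose(torch.nn.AdaptiveAvgPool2d((1, 1))(x),
                      x.mean(dim=(2, 3), keepdim=True))
```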
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index be3575904..6e237c838 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -11,12 +11,14 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
+  // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
+  // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
   // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
+  // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
+  // CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
+  // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[T1]], %[[T2]]] : tensor<?x?xf32>
   // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index c2f69eeaf..4366a542f 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -949,3 +949,37 @@ func.func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
   %0 = torch.aten.to.dtype_layout %arg0, %int7, %int0, %none, %none, %false, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
   return %0 : !torch.vtensor<[?,?],f64>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d(
+// CHECK-SAME:    %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[CST7:.*]] = torch.constant.int 7
+// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[CST2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[CST3:.*]] = torch.constant.int 3
+// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
+// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
+// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.adaptive_avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int7 = torch.constant.int 7
+  %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}