diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 594e66752..a4feae8ac 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7054,6 +7054,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [ }]; } +def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 560ac95b1..8ad5cefc0 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, // Max pooling if (isa(op)) { + AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, return nullptr; } +// AtenMaxPool1dWithIndicesOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = cast(input.getType()); + auto inputElemTy = inputTy.getElementType(); + auto inputShape = inputTy.getShape(); + auto inputRank = inputTy.getRank(); + + auto outValTy = + cast(getTypeConverter()->convertType(op.getType(0))); + auto outIdxTy = + cast(getTypeConverter()->convertType(op.getType(1))); + + if (inputRank <= 1) { + return op.emitError( + "max_pooling1d only supports inputs with rank higher than 1"); + } + + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if (!(matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - 1); + std::copy(dilation.begin(), dilation.end(), + stablehloDilation.begin() + inputRank - 1); + std::copy(kernelSize.begin(), kernelSize.end(), + stablehloKernelSize.begin() + inputRank - 1); + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + DenseI64ArrayAttr baseDilations; + + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + + // no need to reshape here for max_pool_1d. Need to make sure the iota + // dimension. dim=inputRank-2 or dim=inputRank-1? + auto indexTensor = + rewriter + .create( + op->getLoc(), + RankedTensorType::get(inputShape, rewriter.getI64Type()), + inputShapeTensor, static_cast(inputRank - 1)) + .getResult(); + Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); + + auto reduceWindowOp = rewriter.create( + op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, + mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, + windowDimensions, windowStrides, baseDilations, windowDilations, pad); + + // add block. + Block &block = reduceWindowOp.getBody().emplaceBlock(); + auto blockValArgumentType = RankedTensorType::get({}, inputElemTy); + auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); + auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + auto *firstValArg = block.args_begin(); + auto *firstIdxArg = std::next(firstValArg); + auto *secondValArg = std::next(firstIdxArg); + auto *secondIdxArg = std::next(secondValArg); + + stablehlo::ComparisonTypeAttr compareTypeAttr; + if (isa(inputTy.getElementType())) { + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::FLOAT); + } else if (isa(inputTy.getElementType())) { + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::SIGNED); + } + + stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::EQ); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value compareGeResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + Value retValResult = rewriter.create( + op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + + // Get smaller index if compared values are equal. + Value compareEqResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareEqDirectionAttr, compareTypeAttr); + Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, + *secondIdxArg); + Value idxWithGeVal = rewriter.create( + op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = rewriter.create( + op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + rewriter.create( + op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + } + + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); +} + // AtenMaxPool2dWithIndicesOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( #define INSERT_ATEN_POOLING_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) + INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp); INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp); INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp); #undef INSERT_ATEN_POOLING_PATTERN diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a4eb6dcff..ec4336352 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -7047,6 +7047,85 @@ public: }; } // namespace +namespace { +// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices` +// op. +class DecomposeAtenAdaptiveMaxPool1dOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getSelf(); + std::optional maybeRank = getTensorRank(input); + if (!maybeRank) { + return rewriter.notifyMatchFailure(op, "expected input to have a rank"); + } + unsigned rank = *maybeRank; + Value sizeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(rank - 1)); + Value inputSize = rewriter.create(loc, input, sizeDim); + + Value outputShape = op.getOutputSize(); + SmallVector outputShapeSizesTorchInt; + getListConstructElements(outputShape, outputShapeSizesTorchInt); + Value outputSize = outputShapeSizesTorchInt[0]; + + Value constantOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value constantZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constantFalse = rewriter.create(loc, false); + + int64_t outputSizeInt; + if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) { + return rewriter.notifyMatchFailure( + op, "the output size of adaptive_max_pool1d must be a constant int"); + } + + SmallVector kernelSize; + if (outputSizeInt == 1) { + BaseTensorType inputTensorType = cast(input.getType()); + ArrayRef inputShape = inputTensorType.getSizes(); + kernelSize.push_back( + inputShape[rank - 1] == kUnknownSize + ? inputSize + : rewriter.create( + loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); + } else { + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create(loc, inputSize, outputSize); + rewriter.create( + loc, cond, + "unimplemented: only support cases where input and output size are " + "equal for non-unit output size"); + } + kernelSize.push_back(constantOne); + } + + Value kernelSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); + Value strideList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantOne}); + Value paddingSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantZero}); + Value dialationList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantOne}); + + rewriter.replaceOpWithNewOp( + op, op.getType(0), op.getType(1), input, kernelSizeList, strideList, + paddingSizeList, dialationList, + /*ceil_mode=*/constantFalse); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op. @@ -9510,6 +9589,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e487c12a3..3dcd711d2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -513,6 +513,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "AdaptiveAvgPool3dDynamic_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic", "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dDimOneStatic_basic", "AdaptiveMaxPool1dStatic_basic", "AdaptiveMaxPool2dDynamicNoBatch_basic", "AdaptiveMaxPool2dDynamicWithIndices_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index dd68a43bb..17ef8bc5b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -616,6 +616,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit( + "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + ) emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 6d36c6909..cb8adf3f7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1537,6 +1537,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) +class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module): + def __init__(self): + super().__init__() + self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False) + + @export + @annotate_args([None, ([1, 512, 7], torch.float32, True)]) + def forward(self, x): + return self.amp1d(x) + + +@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic()) +def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) + + # AdaptiveMaxPool2d