diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 9432b06fa..bf2773488 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -379,6 +379,7 @@ STABLEHLO_PASS_SET = {
     "ConstantBoolParameterModule_basic",
     "MaskedFillScalarIntValueStaticModule_basic",
     "MaskedFillScalarFloatValueStaticModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AddSizeIntModule_basic",
     "AddSizeIntNegDimModule_basic",
@@ -781,6 +782,7 @@ STABLEHLO_PASS_SET = {
     "ReshapeExpandModule_basic",
     "RollModule_basic",
     "TestMultipleTensorReturn_basic",
+    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
     "BaddbmmStaticModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
@@ -1197,6 +1199,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
     "SliceWholeTensorModule_basic",
     "TensorFloatModule_basic",
     "TensorIntModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
+    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
 }) - {
 ### Test failing in make_fx_tosa but not in tosa
@@ -1239,6 +1243,8 @@ LTC_XFAIL_SET = {
     "_ConvolutionDeprecated2DBenchmarkModule_basic",
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AddIntModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 991f65f9e..8bff52697 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5323,6 +5323,30 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
   }];
 }
 
+def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$output_size
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index b2fd823b8..485d52d15 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6952,6 +6952,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %23 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @__torch__.adaptive_avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.optional) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8266,6 +8327,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" " %0:2 = 
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2012d1c94..ca977eb8d 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3326,6 +3326,85 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
+
+// The logic of this decomposition is the same as in
+// DecomposeAtenAdaptiveAvgPool2dOp, which means that currently only the
+// following two cases are supported:
+//  1. inputSize = outputSize
+//  2. outputSize = 1
+class DecomposeAtenAdaptiveAvgPool1dOp
+    : public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getSelf();
+    std::optional<unsigned> maybeRank = getTensorRank(input);
+    if (!maybeRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned rank = *maybeRank;
+    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 1));
+    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
+
+    Value outputShape = op.getOutputSize();
+    SmallVector<Value> outputShapeSizesTorchInt;
+    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    Value outputSize = outputShapeSizesTorchInt[0];
+
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+
+    int64_t outputSizeInt;
+    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "the output size of adaptive_avg_pool1d must be a constant int");
+    }
+
+    SmallVector<Value> kernelSize;
+    if (outputSizeInt == 1) {
+      BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
+      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+      kernelSize.push_back(
+          inputShape[rank - 1] == kUnknownSize
+              ? inputSize
+              : rewriter.create<Torch::ConstantIntOp>(
+                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
+    } else {
+      Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
+      rewriter.create<RuntimeAssertOp>(
+          loc, cond,
+          "unimplemented: only support cases where input and output size are "
+          "equal for non-unit output size");
+      kernelSize.push_back(constantOne);
+    }
+
+    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne});
+    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantZero});
+
+    rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
+        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
 //
@@ -4800,6 +4879,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index bddce9359..cc4230d65 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -446,6 +446,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 92170fa89..41a14f876 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -565,9 +565,28 @@ def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padd
     else:
         return [nbatch, nInputPlane, outputLength]
 
+# TODO: This should be upstreamed.
+# See https://github.com/pytorch/pytorch/pull/76889 for an example.
+def adaptive_avg_pool1d(self: List[int], out: List[int]):
+    assert len(out) == 1
+    assert len(self) == 2 or len(self) == 3
+
+    for i in range(len(self)):
+        assert self[i] != 0
+
+    shape: List[int] = []
+    for i in range(len(self) - 1):
+        shape.append(self[i])
+    shape.append(out[0])
+
+    return shape
+
 def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
     return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
 
+def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
+    return adaptive_avg_pool1d(self, output_size)
+
 def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
     return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
 
@@ -1407,6 +1426,11 @@ def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return torch.float32
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
+def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
 def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index fb438fdd8..b51f8eb58 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -418,6 +418,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
     emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
+    emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
     emit("aten::permute : (Tensor, int[]) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py
index 84ed83a41..dd18545b0 100644
--- a/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -771,4 +771,88 @@ class AvgPool1dStaticModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: AvgPool1dStaticModule())
 def AvgPool1dStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(2, 4, 20, high=100))
\ No newline at end of file
+    module.forward(tu.randint(2, 4, 20, high=100))
+
+
+# ==============================================================================
+
+
+class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap1d = torch.nn.AdaptiveAvgPool1d(7)
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap1d(x)
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeStaticModule())
+def AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7))
+
+class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap1d = torch.nn.AdaptiveAvgPool1d(7)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap1d(x)
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7))
+
+class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap1d = torch.nn.AdaptiveAvgPool1d(1)
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap1d(x)
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeStaticModule())
+def AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7))
+
+class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap1d = torch.nn.AdaptiveAvgPool1d(1)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap1d(x)
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7))
\ No newline at end of file
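
For reference, below is a minimal PyTorch-level sketch (not part of the patch) of the two cases the decomposition above supports. It uses only the public torch.nn.functional API on an assumed rank-3 input; any other output size is rejected by the decomposition's runtime assert.

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 512, 7)

    # Case 1: output size == input size -> avg_pool1d with kernel_size = 1.
    same = F.avg_pool1d(x, kernel_size=1, stride=1, padding=0,
                        ceil_mode=False, count_include_pad=True)
    assert torch.allclose(same, F.adaptive_avg_pool1d(x, 7))

    # Case 2: output size == 1 -> avg_pool1d with kernel_size = input size.
    unit = F.avg_pool1d(x, kernel_size=7, stride=1, padding=0,
                        ceil_mode=False, count_include_pad=True)
    assert torch.allclose(unit, F.adaptive_avg_pool1d(x, 1))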