diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 8326d11c8..f4fc06b51 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -524,6 +524,7 @@ TOSA_PASS_SET = {
     "SquareModule_basic",
     "MaxPool2dStaticModule_basic",
     "ResNet18StaticModule_basic",
+    "ReduceAmaxKeepDim_basic",
     "NativeLayerNormModule4D_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
     "PermuteModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 5c61f9f27..3af9fb096 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6756,6 +6756,31 @@ def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
   }];
 }
 
+def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenAmaxOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 59d3fe74f..f68fc7a96 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -167,6 +167,55 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
   return sub;
 }
 
+namespace {
+/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
+/// number of dimensions across which the max needs to be computed.
+/// Eg:
+///   INPUT:
+///     final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
+///
+///   OUTPUT:
+///     input_1 = aten.max.dim(initial_input, 2, keepdim)   #1
+///     input_2 = aten.max.dim(input_1, 1, keepdim)         #2
+///     final_output = aten.max.dim(input_2, 0, keepdim)    #3
+///
+/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
+///       of the `aten.amax` op and create an `aten.max.dim` op.
+///       Input tensor to the next `aten.max.dim` op is thus the output of the
+///       previous `aten.max.dim` op.
+class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAmaxOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<int64_t> dims;
+    if (!matchPattern(op.dim(), m_TorchListOfConstantInts(dims)))
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const dim parameter unsupported");
+
+    bool keepDim;
+    if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for keepDim");
+
+    Value input = op.self();
+    std::sort(dims.begin(), dims.end());
+    // For every dimension included in `dim` of the op, iterated over in
+    // reverse order, we create a call to aten.max.dim.
+    for (int64_t i = dims.size() - 1; i >= 0; i--) {
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dims[i]));
+      // The input to the next invocation of aten.max.dim is the output of the
+      // previous aten.max.dim op.
+      input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
+    }
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+} // end namespace
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -3364,6 +3413,8 @@ public:
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenAmaxOp>(context);
+    target.addIllegalOp<AtenAmaxOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b0d961ced..96e9e78a2 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1006,9 +1006,9 @@ void TypeAnalysis::visitOperation(Operation *op,
         getDtypeOrDefault(mean.getContext(), mean.dtype(), defaultDtype);
     visitReductionAlongAllDimsOp(mean, dtype, operands);
     return;
-  } else if (auto max = dyn_cast<AtenMaxOp>(op)) {
+  } else if (isa<AtenMaxOp, AtenAmaxOp>(op)) {
     Type dtype = operands[0]->getValue().dtype;
-    visitReductionAlongAllDimsOp(max, dtype, operands);
+    visitReductionAlongAllDimsOp(op, dtype, operands);
     return;
   } else if (isa(op)) {
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index babc6dd18..e60c31bf8 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5842,6 +5842,13 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %1 : !torch.tuple<list<int>, list<int>>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
+"    %1 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index a3d4b11cf..72aaf857f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -589,6 +589,9 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
     reduced_shape = _reduce_along_dim(self, dim, keepdim)
     return reduced_shape, reduced_shape
 
+def aten〇amax(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
+
 def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e71c5ad6f..819e1f770 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -478,6 +478,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)")
     emit("aten::max : (Tensor) -> (Tensor)")
     emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
+    emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
     emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py
index 8c9225869..2029727ca 100644
--- a/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -462,6 +462,78 @@ def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ReduceAmaxSingleDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.amax(a, 1)
+
+@register_test_case(module_factory=lambda: ReduceAmaxSingleDim())
+def ReduceAmaxSingleDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, high=100))
+
+# ==============================================================================
+
+class ReduceAmaxMultiDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.amax(a, (0, 2))
+
+@register_test_case(module_factory=lambda: ReduceAmaxMultiDim())
+def ReduceAmaxMultiDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, high=100))
+
+# ==============================================================================
+
+class ReduceAmaxOutOfOrderDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.amax(a, (2, 1, 3))
+
+@register_test_case(module_factory=lambda: ReduceAmaxOutOfOrderDim())
+def ReduceAmaxOutOfOrderDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, 6, high=100))
+
+# ==============================================================================
+
+class ReduceAmaxKeepDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.amax(a, (0, 2), keepdim=True)
+
+@register_test_case(module_factory=lambda: ReduceAmaxKeepDim())
+def ReduceAmaxKeepDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, high=100))
+
+# ==============================================================================
+
 class ReduceL1NormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
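
Note on the decomposition (illustration, not part of the patch): the DecomposeAtenAmaxOp pattern turns one aten.amax over several dims into a chain of aten.max.dim ops, visiting the sorted dims in reverse so that, with keepdim=False, reducing a higher dim never shifts the index of a lower dim that still has to be reduced. The following PyTorch sketch, using a hypothetical helper amax_via_max_dim written only for this note, checks that equivalence on the same shapes the new e2e tests use:

    import torch

    def amax_via_max_dim(t, dims, keepdim=False):
        # Mirror DecomposeAtenAmaxOp: reduce along each requested dim,
        # highest dim first, feeding each result into the next reduction.
        for d in sorted(dims, reverse=True):
            t = torch.max(t, dim=d, keepdim=keepdim).values
        return t

    a = torch.rand(3, 4, 5)
    assert torch.equal(amax_via_max_dim(a, (0, 2)), torch.amax(a, dim=(0, 2)))
    assert torch.equal(amax_via_max_dim(a, (0, 2), keepdim=True),
                       torch.amax(a, dim=(0, 2), keepdim=True))

    b = torch.rand(3, 4, 5, 6)
    assert torch.equal(amax_via_max_dim(b, (2, 1, 3)), torch.amax(b, dim=(2, 1, 3)))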
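
Shape inference note: the aten〇amax shape function reuses upstream_shape_functions.sum_mean_dim, so the result shape keeps every dim not listed in `dim`, drops the reduced dims when keepdim=False, and pins them to 1 when keepdim=True. A small illustrative re-implementation of that rule (an assumption-level sketch, not the upstream sum_mean_dim itself):

    from typing import List

    def amax_result_shape(self_shape: List[int], dims: List[int], keepdim: bool) -> List[int]:
        # Normalize negative dims, then keep, drop, or pin each size per the rule above.
        dims = [d % len(self_shape) for d in dims]
        out: List[int] = []
        for i, size in enumerate(self_shape):
            if i in dims:
                if keepdim:
                    out.append(1)
            else:
                out.append(size)
        return out

    assert amax_result_shape([3, 4, 5], [0, 2], keepdim=False) == [4]       # ReduceAmaxMultiDim
    assert amax_result_shape([3, 4, 5], [0, 2], keepdim=True) == [1, 4, 1]  # ReduceAmaxKeepDim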