diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index da5ee799a..50fa1f8e6 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -277,6 +277,10 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
   if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
     return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
 
+  if (isa<AtenAllDimOp>(op)) {
+    return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
+  }
+
   op->emitError("unimplemented lowering in createInitElementForReduceOp");
   return nullptr;
 }
@@ -357,6 +361,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
     auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
     auto pow = b.create<math::PowFOp>(loc, abs, ord);
     return b.create<arith::AddFOp>(loc, pow, result);
+  } else if (isa<AtenAllDimOp>(op)) {
+    Value elem = payloadArgs[0];
+    Value result = payloadArgs[1];
+    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
+    return b.create<arith::AndIOp>(loc, self, result);
   }
   op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
   return nullptr;
@@ -447,6 +456,9 @@ private:
     if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
       return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
 
+    if (auto allOp = dyn_cast<AtenAllDimOp>(op))
+      return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
+
     return rewriter.notifyMatchFailure(op, "not a supported reduce op");
   }
 
@@ -535,6 +547,9 @@ private:
         !elemType.isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(
           op, "only float types are valid for vector norm ops");
+    if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
+        elemType.getIntOrFloatBitWidth() == 8)
+      return rewriter.notifyMatchFailure(op, "uint8 is not supported");
     // No checks for all other reduction operations
     return success();
   }
@@ -610,6 +625,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenAllDimOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   patterns.add<ConvertReductionOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 57ece8cfd..4290ce23c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7006,6 +7006,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -11809,6 +11814,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %int11 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 922b207a2..dadd87a15 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -543,6 +543,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
 def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
+def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.argmax(self, dim, keepdim)
+
 def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
     reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
     return reduced_shape, reduced_shape
@@ -3766,6 +3769,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
         return self_dtype
     return torch.bool
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
+def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype == torch.uint8:
+        return self_dtype
+    return torch.bool
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 8418d1ae8..ea2ff1609 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -316,6 +316,78 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ReduceAllDimEmpty(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=0, keepdim=False)
+
+@register_test_case(module_factory=lambda: ReduceAllDimEmpty())
+def ReduceAllDimEmpty_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([]))
+
+# ==============================================================================
+
+class ReduceAllDimFloat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1,-1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=1, keepdim=True)
+
+@register_test_case(module_factory=lambda: ReduceAllDimFloat())
+def ReduceAllDimFloat_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]]))
+
+# ==============================================================================
+
+class ReduceAllDimInt(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1,-1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=1, keepdim=True)
+
+@register_test_case(module_factory=lambda: ReduceAllDimInt())
+def ReduceAllDimInt_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAllDimBool(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1,-1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=1, keepdim=False)
+
+@register_test_case(module_factory=lambda: ReduceAllDimBool())
+def ReduceAllDimBool_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[True, False, True], [True, True, True]]))
+
+# ==============================================================================
+
 class ReduceMaxAlongDim(torch.nn.Module):
     def __init__(self):
         super().__init__()