diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index cf2d2beee..af5268a06 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -270,6 +270,8 @@ private: "`keepdim` must be a constant bool"); SmallVector dimList; + bool isNoneOrEmptyDimList = + op.dim().getType().template isa(); if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) { // Fix negative dimensions, if any, before adding to the list. for (int64_t dim : dimList) { @@ -278,13 +280,16 @@ private: if (isValidDim(dim, inputType.getRank())) opInfo.dimSet.insert(dim); } - } else if (op.dim().getType().template isa()) { + if (dimList.empty()) + isNoneOrEmptyDimList = true; + } else if (!isNoneOrEmptyDimList) { + return rewriter.notifyMatchFailure( + op, "`dim` argument must be a constant int list or None"); + } + if (isNoneOrEmptyDimList) { // If no dimensions were specified, reduce along all dimensions for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); - } else { - return rewriter.notifyMatchFailure( - op, "`dim` argument must be a constant int list or None"); } return opInfo; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b4bc870bd..f80e92df3 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1012,6 +1012,7 @@ public: PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.self(); + unsigned inputRank = getTensorRank(input); Value dimList = op.dim(); Value keepDim = op.keepdim(); Value dtype = op.dtype(); @@ -1036,12 +1037,18 @@ public: loc, outputType, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. - Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, input, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); + Value productDimSize; + // Case: Reduce along all dims. + if (dimListConstruct.elements().empty() && inputRank != 0) { + productDimSize = rewriter.create(loc, input); + } else { + productDimSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + for (Value dim : dimListConstruct.elements()) { + Value dimSize = rewriter.create(loc, input, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } } rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, productDimSize); diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 5693acea8..06e409932 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -5566,9 +5566,25 @@ module { } func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { %none = torch.constant.none - %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int + %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + torch.prim.If.yield %arg1 : !torch.list + } + %3 = torch.derefine %none : !torch.none to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { %true = torch.constant.bool true @@ -5657,25 +5673,55 @@ module { return %1 : !torch.tuple, list> } func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.optional to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int + %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + torch.prim.If.yield %arg1 : !torch.list + } + %3 = torch.derefine %arg3 : !torch.optional to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { + %true = torch.constant.bool true %none = torch.constant.none + %int0 = torch.constant.int 0 %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %2 = torch.prim.ListConstruct : () -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list + %1 = torch.prim.If %0 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool } else { - %2 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool } - return %1 : !torch.list + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %3 = torch.derefine %arg3 : !torch.optional to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index bff4cb7b9..a4c3011bc 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -490,6 +490,8 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: return [] def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: + if len(dim)==0: + dim = list(range(len(self))) return upstream_shape_functions.mean_dim(self, dim, keepdim, None) def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: @@ -533,13 +535,14 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[ return reduced_shape, reduced_shape def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + if len(dim)==0: + dim = list(range(len(self))) return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None: - return upstream_shape_functions.mean_dim(self, [], keepdim, dtype) - else: - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + if dim is None or len(dim)==0: + dim = list(range(len(self))) + return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) def aten〇permute(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index be1ed5078..3fcf29534 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -106,6 +106,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.sum(a, dim=[]) + + +@register_test_case(module_factory=lambda: ReduceSumDimIntListEmptyDimModule()) +def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class ReduceSumUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py index 01c07d87a..4d27e5263 100644 --- a/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/python/torch_mlir_e2e_test/test_suite/stats.py @@ -180,6 +180,26 @@ class MeanDimNegativeModule(torch.nn.Module): def MeanDimNegativeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class MeanDimEmptyDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.mean(x, dim=[]) + + +@register_test_case(module_factory=lambda: MeanDimEmptyDimModule()) +def MeanDimEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + # ============================================================================== class VarUnbiasedModule(torch.nn.Module): @@ -410,7 +430,7 @@ def VarDimNegativeModule_basic(module, tu: TestUtils): # ============================================================================== -class VarDimKeepDimFalseModule(torch.nn.Module): +class VarDimEmptyDimModule(torch.nn.Module): def __init__(self): super().__init__() @@ -421,11 +441,11 @@ class VarDimKeepDimFalseModule(torch.nn.Module): ([-1, -1, -1], torch.float32, True), ]) def forward(self, x): - return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=False) + return torch.ops.aten.var(x, dim=[], keepdim=False) -@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule()) -def VarDimKeepDimFalseModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: VarDimEmptyDimModule()) +def VarDimEmptyDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5))