From f07f7d20f97cef3e45d08c595348906fecaa7f11 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Thu, 18 Aug 2022 08:23:43 -0700 Subject: [PATCH] Clean up shape functions that use `sum_mean_dim` (#1217) I recently fixed the handling of the `dim` argument in `sum_mean_dim` (https://github.com/pytorch/pytorch/commit/59fccab85775da7a0ecf33bda241f81eade3ad4b). Therefore, the checks that the `dim` input is `None` or `[]` are no longer needed. --- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 397 +++--------------- .../jit_ir/build_tools/shape_lib_gen.py | 12 - 2 files changed, 68 insertions(+), 341 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 6cdffe12d..ff42c75bd 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3494,137 +3494,6 @@ module { %6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> return %6 : !torch.tuple, list, list> } - func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int - %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %3 = torch.prim.ListConstruct : () -> !torch.list - %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %3, %4 : !torch.list, !torch.int -> !torch.list - %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.append.t %3, %6 : !torch.list, !torch.int -> !torch.list - %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %8, %true, init() { - ^bb0(%arg9: !torch.int): - %9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %10 = torch.prim.If %1 -> (!torch.int) { - %11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %12 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - torch.prim.If %arg6 -> () { - %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int - %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int - %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int - %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int - %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list, !torch.int -> !torch.int - %23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int - %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %3 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.int, %arg7: !torch.optional>) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %0 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %15 : !torch.list - } else { - %15 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list - torch.prim.If.yield %15 : !torch.list - } - %2 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.list) { - %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %15 : !torch.list - } else { - %15 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list - torch.prim.If.yield %15 : !torch.list - } - %4 = torch.aten.__is__ %arg7, %none : !torch.optional>, !torch.none -> !torch.bool - %5 = torch.prim.If %4 -> (!torch.list) { - %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %15 : !torch.list - } else { - %15 = torch.prim.unchecked_cast %arg7 : !torch.optional> -> !torch.list - torch.prim.If.yield %15 : !torch.list - } - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %9 = torch.prim.ListConstruct : () -> !torch.list - %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list - %12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list - %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %14, %true, init() { - ^bb0(%arg8: !torch.int): - %15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %16 = torch.prim.If %7 -> (!torch.int) { - %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int - %33 = torch.aten.__getitem__.t %5, %32 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %33 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int - %23 = torch.aten.__getitem__.t %1, %22 : !torch.list, !torch.int -> !torch.int - %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int - %26 = torch.aten.__getitem__.t %3, %25 : !torch.list, !torch.int -> !torch.int - %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int - %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int - %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int - %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int - %31 = torch.aten.append.t %9, %30 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %9 : !torch.list - } func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { %none = torch.constant.none %str = torch.constant.str "AssertionError: " @@ -4502,76 +4371,90 @@ module { } func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %false = torch.constant.bool false %true = torch.constant.bool true %none = torch.constant.none + %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.list) { - %4 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %4 : !torch.list + %2 = torch.prim.If %1 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool } else { - %4 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %4 : !torch.list + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool } - %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %3, %true, init() { + %3 = torch.prim.If %2 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + torch.prim.Loop %4, %true, init() { ^bb0(%arg4: !torch.int): - %4 = torch.aten.len.t %2 : !torch.list -> !torch.int - %5 = torch.prim.Loop %4, %true, init(%false) { + %5 = torch.aten.len.t %3 : !torch.list -> !torch.int + %6 = torch.prim.Loop %5, %true, init(%false) { ^bb0(%arg5: !torch.int, %arg6: !torch.bool): - %6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.int) { + %7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool + %10 = torch.prim.If %9 -> (!torch.int) { torch.prim.If.yield %int1 : !torch.int } else { - torch.prim.If.yield %7 : !torch.int + torch.prim.If.yield %8 : !torch.int } - %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int - %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.bool) { + %11 = torch.aten.neg.int %10 : !torch.int -> !torch.int + %12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> !torch.bool + %14 = torch.prim.If %13 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool + %20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %20 : !torch.bool } - %14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool - torch.prim.If %14 -> () { + %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool + torch.prim.If %15 -> () { torch.prim.If.yield } else { torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } - %15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - %19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %19 : !torch.int + %16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + %17 = torch.prim.If %16 -> (!torch.int) { + %20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + torch.prim.If.yield %20 : !torch.int } else { - torch.prim.If.yield %6 : !torch.int + torch.prim.If.yield %7 : !torch.int } - %17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.bool) { + %18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool + %19 = torch.prim.If %18 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { torch.prim.If.yield %arg6 : !torch.bool } - torch.prim.Loop.condition %true, iter(%18 : !torch.bool) + torch.prim.Loop.condition %true, iter(%19 : !torch.bool) } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - torch.prim.If %5 -> () { + torch.prim.If %6 -> () { torch.prim.If %arg2 -> () { - %6 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } else { torch.prim.If.yield } torch.prim.If.yield } else { - %6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.append.t %0, %6 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.append.t %0, %7 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } torch.prim.Loop.condition %true, iter() @@ -4582,8 +4465,8 @@ module { %false = torch.constant.bool false %true = torch.constant.bool true %none = torch.constant.none - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %str = torch.constant.str "AssertionError: " %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list %1 = torch.prim.ListConstruct : () -> !torch.list @@ -5712,101 +5595,26 @@ module { return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { - %true = torch.constant.bool true %none = torch.constant.none - %int0 = torch.constant.int 0 - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %7 = torch.aten.len.t %6 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } - %2 = torch.prim.If %1 -> (!torch.list) { - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg4: !torch.int): - %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.derefine %2 : !torch.list to !torch.optional> - %4 = torch.derefine %none : !torch.none to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %5 : !torch.list + %0 = torch.derefine %none : !torch.none to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { - %true = torch.constant.bool true %none = torch.constant.none - %int0 = torch.constant.int 0 - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %7 = torch.aten.len.t %6 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } - %2 = torch.prim.If %1 -> (!torch.list) { - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg4: !torch.int): - %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.derefine %2 : !torch.list to !torch.optional> - %4 = torch.derefine %none : !torch.none to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %5 : !torch.list + %0 = torch.derefine %none : !torch.none to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { %0 = torch.prim.ListConstruct : () -> !torch.list return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { - %true = torch.constant.bool true %none = torch.constant.none - %int0 = torch.constant.int 0 - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %7 = torch.aten.len.t %6 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } - %2 = torch.prim.If %1 -> (!torch.list) { - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg4: !torch.int): - %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.derefine %2 : !torch.list to !torch.optional> - %4 = torch.derefine %none : !torch.none to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %5 : !torch.list + %0 = torch.derefine %none : !torch.none to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list { %none = torch.constant.none @@ -5861,66 +5669,14 @@ module { return %1 : !torch.tuple, list> } func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %int0 = torch.constant.int 0 - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %7 = torch.aten.len.t %6 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } - %2 = torch.prim.If %1 -> (!torch.list) { - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg4: !torch.int): - %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.derefine %2 : !torch.list to !torch.optional> - %4 = torch.derefine %arg3 : !torch.optional to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %5 : !torch.list + %0 = torch.derefine %arg3 : !torch.optional to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %int0 = torch.constant.int 0 - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %7 = torch.aten.len.t %6 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } - %2 = torch.prim.If %1 -> (!torch.list) { - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg4: !torch.int): - %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.derefine %2 : !torch.list to !torch.optional> - %4 = torch.derefine %arg3 : !torch.optional to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %5 : !torch.list + %0 = torch.derefine %arg3 : !torch.optional to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list @@ -7029,26 +6785,9 @@ module { return %none : !torch.none } func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %0 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { - ^bb0(%arg5: !torch.int): - %7 = torch.aten.append.t %6, %arg5 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list - } else { - %5 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list - } - %2 = torch.derefine %1 : !torch.list to !torch.optional> - %3 = torch.derefine %arg4 : !torch.optional to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %0 = torch.derefine %arg4 : !torch.optional to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } } )mlir"); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index cd07633dc..7ec3b67f8 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -524,21 +524,15 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: return [] def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: - if dim is None or len(dim)==0: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: - if dim is None or len(dim)==0: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: return [] def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: - if dim is None or len(dim)==0: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): @@ -574,13 +568,9 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[ return reduced_shape, reduced_shape def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None or len(dim)==0: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None or len(dim)==0: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇permute(self: List[int], dims: List[int]) -> List[int]: @@ -1169,8 +1159,6 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen return [hacky_get_unknown_dimension_size()] def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None: - dim = list(range(len(self))) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) # ==============================================================================