From 85f383ce0b6855592c7d440f744d9f41d73aa21e Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 17 Aug 2022 00:11:04 -0400 Subject: [PATCH] Bump the shape lib to match the upstream functions currently in PyTorch (#1236) Bumps the shape library: - Updates the function signature for aten.arange.start_step - upstream_shape_functions.mean_dim -> upstream_shape_functions.sum_mean_dim --- build_tools/update_shape_lib.sh | 2 +- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 377 ++++++++++++------ .../jit_ir/build_tools/shape_lib_gen.py | 14 +- 3 files changed, 269 insertions(+), 124 deletions(-) diff --git a/build_tools/update_shape_lib.sh b/build_tools/update_shape_lib.sh index 16526569c..b2d619d34 100755 --- a/build_tools/update_shape_lib.sh +++ b/build_tools/update_shape_lib.sh @@ -28,5 +28,5 @@ fi PYTHONPATH="${pypath}" python \ -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \ - --pytorch_op_extensions=${ext_module} \ + --pytorch_op_extensions=${ext_module:-""} \ --torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index eaa848a0e..6cdffe12d 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3494,6 +3494,137 @@ module { %6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> return %6 : !torch.tuple, list, list> } + func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int + %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %3 = torch.prim.ListConstruct : () -> !torch.list + %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %5 = torch.aten.append.t %3, %4 : !torch.list, !torch.int -> !torch.list + %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %7 = torch.aten.append.t %3, %6 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %8, %true, init() { + ^bb0(%arg9: !torch.int): + %9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %10 = torch.prim.If %1 -> (!torch.int) { + %11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %12 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + torch.prim.If %arg6 -> () { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } else { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list, !torch.int -> !torch.int + %23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %3 : !torch.list + } + func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.int, %arg7: !torch.optional>) -> !torch.list { + %true = torch.constant.bool true + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %2 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool + %3 = torch.prim.If %2 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %4 = torch.aten.__is__ %arg7, %none : !torch.optional>, !torch.none -> !torch.bool + %5 = torch.prim.If %4 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg7 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %9 = torch.prim.ListConstruct : () -> !torch.list + %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list + %12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list + %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %14, %true, init() { + ^bb0(%arg8: !torch.int): + %15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %16 = torch.prim.If %7 -> (!torch.int) { + %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %33 = torch.aten.__getitem__.t %5, %32 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %33 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + %17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.__getitem__.t %1, %22 : !torch.list, !torch.int -> !torch.int + %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.__getitem__.t %3, %25 : !torch.list, !torch.int -> !torch.int + %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int + %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int + %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int + %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int + %31 = torch.aten.append.t %9, %30 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %9 : !torch.list + } func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { %none = torch.constant.none %str = torch.constant.str "AssertionError: " @@ -4369,70 +4500,78 @@ module { } return %6 : !torch.list } - func.func @__torch__.torch.jit._shape_functions.mean_dim(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { - %none = torch.constant.none + func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 %false = torch.constant.bool false %true = torch.constant.bool true + %none = torch.constant.none %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { + %1 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.list) { + %4 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.If.yield %4 : !torch.list + } else { + %4 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %4 : !torch.list + } + %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + torch.prim.Loop %3, %true, init() { ^bb0(%arg4: !torch.int): - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.prim.Loop %2, %true, init(%false) { + %4 = torch.aten.len.t %2 : !torch.list -> !torch.int + %5 = torch.prim.Loop %4, %true, init(%false) { ^bb0(%arg5: !torch.int, %arg6: !torch.bool): - %4 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.int) { + %6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list, !torch.int -> !torch.int + %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + %9 = torch.prim.If %8 -> (!torch.int) { torch.prim.If.yield %int1 : !torch.int } else { - torch.prim.If.yield %5 : !torch.int + torch.prim.If.yield %7 : !torch.int } - %8 = torch.aten.neg.int %7 : !torch.int -> !torch.int - %9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.lt.int %4, %8 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { + %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int + %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int + %12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool + %13 = torch.prim.If %12 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %17 = torch.aten.gt.int %4, %9 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool + %19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %19 : !torch.bool } - %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool - torch.prim.If %12 -> () { + %14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool + torch.prim.If %14 -> () { torch.prim.If.yield } else { torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } - %13 = torch.aten.lt.int %4, %int0 : !torch.int, !torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.int) { - %17 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %17 : !torch.int + %15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + %16 = torch.prim.If %15 -> (!torch.int) { + %19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int + torch.prim.If.yield %19 : !torch.int } else { - torch.prim.If.yield %4 : !torch.int + torch.prim.If.yield %6 : !torch.int } - %15 = torch.aten.eq.int %arg4, %14 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.bool) { + %17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool + %18 = torch.prim.If %17 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { torch.prim.If.yield %arg6 : !torch.bool } - torch.prim.Loop.condition %true, iter(%16 : !torch.bool) + torch.prim.Loop.condition %true, iter(%18 : !torch.bool) } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - torch.prim.If %3 -> () { + torch.prim.If %5 -> () { torch.prim.If %arg2 -> () { - %4 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list + %6 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } else { torch.prim.If.yield } torch.prim.If.yield } else { - %4 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %0, %4 : !torch.list, !torch.int -> !torch.list + %6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int + %7 = torch.aten.append.t %0, %6 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } torch.prim.Loop.condition %true, iter() @@ -4442,10 +4581,10 @@ module { func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> { %false = torch.constant.bool false %true = torch.constant.bool true + %none = torch.constant.none %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %str = torch.constant.str "AssertionError: " - %none = torch.constant.none %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list %1 = torch.prim.ListConstruct : () -> !torch.list %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -5441,10 +5580,6 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } - func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5524,6 +5659,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5580,27 +5719,28 @@ module { %1 = torch.prim.If %0 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %7 = torch.aten.len.t %6 : !torch.list -> !torch.int + %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %8 : !torch.bool } %2 = torch.prim.If %1 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { + %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %7 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %6, %true, init() { ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list + torch.prim.If.yield %7 : !torch.list } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %6 : !torch.list } - %3 = torch.derefine %none : !torch.none to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %3 = torch.derefine %2 : !torch.list to !torch.optional> + %4 = torch.derefine %none : !torch.none to !torch.any + %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %5 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { %true = torch.constant.bool true @@ -5610,27 +5750,28 @@ module { %1 = torch.prim.If %0 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %7 = torch.aten.len.t %6 : !torch.list -> !torch.int + %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %8 : !torch.bool } %2 = torch.prim.If %1 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { + %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %7 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %6, %true, init() { ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list + torch.prim.If.yield %7 : !torch.list } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %6 : !torch.list } - %3 = torch.derefine %none : !torch.none to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %3 = torch.derefine %2 : !torch.list to !torch.optional> + %4 = torch.derefine %none : !torch.none to !torch.any + %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %5 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { %0 = torch.prim.ListConstruct : () -> !torch.list @@ -5644,27 +5785,28 @@ module { %1 = torch.prim.If %0 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %7 = torch.aten.len.t %6 : !torch.list -> !torch.int + %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %8 : !torch.bool } %2 = torch.prim.If %1 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { + %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %7 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %6, %true, init() { ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list + torch.prim.If.yield %7 : !torch.list } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %6 : !torch.list } - %3 = torch.derefine %none : !torch.none to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %3 = torch.derefine %2 : !torch.list to !torch.optional> + %4 = torch.derefine %none : !torch.none to !torch.any + %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %5 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list { %none = torch.constant.none @@ -5726,27 +5868,28 @@ module { %1 = torch.prim.If %0 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %7 = torch.aten.len.t %6 : !torch.list -> !torch.int + %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %8 : !torch.bool } %2 = torch.prim.If %1 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { + %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %7 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %6, %true, init() { ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list + torch.prim.If.yield %7 : !torch.list } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %6 : !torch.list } - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %3 = torch.derefine %2 : !torch.list to !torch.optional> + %4 = torch.derefine %arg3 : !torch.optional to !torch.any + %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %5 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { %true = torch.constant.bool true @@ -5756,27 +5899,28 @@ module { %1 = torch.prim.If %0 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %7 = torch.aten.len.t %6 : !torch.list -> !torch.int + %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %8 : !torch.bool } %2 = torch.prim.If %1 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { + %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %7 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %6, %true, init() { ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.append.t %7, %arg4 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list + torch.prim.If.yield %7 : !torch.list } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list + %6 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %6 : !torch.list } - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %4 : !torch.list + %3 = torch.derefine %2 : !torch.list to !torch.optional> + %4 = torch.derefine %arg3 : !torch.optional to !torch.any + %5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %5 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list @@ -6889,21 +7033,22 @@ module { %none = torch.constant.none %0 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool %1 = torch.prim.If %0 -> (!torch.list) { - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %5 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %4, %true, init() { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { ^bb0(%arg5: !torch.int): - %6 = torch.aten.append.t %5, %arg5 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.append.t %6, %arg5 : !torch.list, !torch.int -> !torch.list torch.prim.Loop.condition %true, iter() } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %5 : !torch.list + torch.prim.If.yield %6 : !torch.list } else { - %4 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - torch.prim.If.yield %4 : !torch.list + %5 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list } - %2 = torch.derefine %arg4 : !torch.optional to !torch.any - %3 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %1, %arg3, %2) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %3 : !torch.list + %2 = torch.derefine %1 : !torch.list to !torch.optional> + %3 = torch.derefine %arg4 : !torch.optional to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } } )mlir"); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index d990404b2..cd07633dc 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -526,12 +526,12 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: if dim is None or len(dim)==0: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, None) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: if dim is None or len(dim)==0: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, None) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: return [] @@ -539,7 +539,7 @@ def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: if dim is None or len(dim)==0: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, None) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self)) @@ -576,12 +576,12 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[ def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: if dim is None or len(dim)==0: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: if dim is None or len(dim)==0: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇permute(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -813,7 +813,7 @@ def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]: def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self -def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: +def aten〇arange〇start_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory) def aten〇arange〇start(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: @@ -1171,7 +1171,7 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: if dim is None: dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) # ============================================================================== # Shape library generator main().