From a7449785ec5fa7f0093cf2eb719f6aa8f9fe3507 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Fri, 24 Mar 2023 09:13:43 -0700 Subject: [PATCH] Use upstream shape functions when available (#1952) There are several ops that have their shape function upstream and had not been updated in Torch-MLIR to use the upstream version. This commit updates those shape function. In addition, TODOs have been added for shape functions that should be upstream but are not. --- .../Transforms/AbstractInterpLibrary.cpp | 320 ++---------------- .../build_tools/abstract_interp_lib_gen.py | 79 +---- 2 files changed, 38 insertions(+), 361 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 5b91bd653..3150e3bb0 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6348,56 +6348,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %none = torch.constant.none\n" -" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.list) {\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" torch.prim.If.yield %2 : !torch.list\n" -" } else {\n" -" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__._reduce_along_dim(%arg0, %2, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" -" torch.prim.If.yield %3 : !torch.list\n" -" }\n" -" return %1 : !torch.list\n" -" }\n" -" func.func @__torch__._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" -" %true = torch.constant.bool true\n" -" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list\n" -" %5 = torch.prim.min.self_int %4 : !torch.list -> !torch.int\n" -" torch.prim.Loop %5, %true, init() {\n" -" ^bb0(%arg3: !torch.int):\n" -" %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If %arg2 -> () {\n" -" %8 = torch.aten.append.t %2, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" } else {\n" -" %8 = torch.aten.append.t %2, %6 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" return %2 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" -" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" -" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %1 : !torch.tuple, list>\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" " %none = torch.constant.none\n" @@ -6462,102 +6425,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.bmm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n" -" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: bmm only supports 3D tensors\"\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %13 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.baddbmm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list {\n" -" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n" -" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: baddbmm only supports 3D tensors\"\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %11 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %13 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.embedding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.embedding(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.int, !torch.bool, !torch.bool) -> !torch.list\n" @@ -7124,23 +6997,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n" -" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" -" %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %3 : !torch.tuple, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" @@ -7210,7 +7068,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" -" return %arg0 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" @@ -7320,98 +7179,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" -" %int-1 = torch.constant.int -1\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %false = torch.constant.bool false\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %17 : !torch.bool\n" -" }\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.prim.ListConstruct : () -> !torch.list\n" -" %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" -" %16 = torch.aten.len.t %15 : !torch.list -> !torch.int\n" -" %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %18 = torch.prim.If %17 -> (!torch.bool) {\n" -" %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %20 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If.yield %18 : !torch.bool\n" -" }\n" -" torch.prim.If %11 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.tuple, list>) {\n" -" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list\n" -" %17 = torch.prim.TupleConstruct %16, %9 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %17 : !torch.tuple, list>\n" -" } else {\n" -" %15 = torch.prim.TupleConstruct %9, %9 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %15 : !torch.tuple, list>\n" -" }\n" -" return %14 : !torch.tuple, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" @@ -7430,54 +7199,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" -" %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop %3, %true, init() {\n" -" ^bb0(%arg5: !torch.int):\n" -" %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.append.t %0, %8 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %6, %true, init() {\n" -" ^bb0(%arg5: !torch.int):\n" -" %8 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" -" return %7 : !torch.tuple, list, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list, list>\n" +" return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.prim.If %arg5 -> (!torch.tuple, list, list>) {\n" -" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list\n" -" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list\n" -" %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" -" torch.prim.If.yield %5 : !torch.tuple, list, list>\n" -" } else {\n" -" %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" -" torch.prim.If.yield %3 : !torch.tuple, list, list>\n" -" }\n" +" %0 = call @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool) -> !torch.tuple, list, list>\n" " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index cfa21fe15..8d4d3cf34 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -14,6 +14,7 @@ import torch.jit._shape_functions as upstream_shape_functions from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar +# TODO: upstream this def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): assert len(weight) == 2 assert len(indices) == 1 @@ -343,17 +344,6 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) -def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): - dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self)) - out: List[int] = [] - for i, self_dim in enumerate(self): - if i == dim: - if keepdim: - out.append(1) - else: - out.append(self_dim) - return out - @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. @@ -364,15 +354,13 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. ]) def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: - if dim is None: - return [] - return _reduce_along_dim(self, dim, keepdim) + return upstream_shape_functions.argmax(self, dim, keepdim) def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: - return _reduce_along_dim(self, dim, keepdim) + return upstream_shape_functions.argmax(self, dim, keepdim) def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]: - reduced_shape = _reduce_along_dim(self, dim, keepdim) + reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) return reduced_shape, reduced_shape def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: @@ -393,6 +381,7 @@ def aten〇transpose〇int〡shape(self: List[int], dim0: int, dim1: int) -> Lis def aten〇t〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.transpose(self, 0, 1) +# TODO: upstream this def aten〇numpy_T〡shape(self: List[int]) -> List[int]: result_shape: List[int] = [] for i in self: @@ -419,22 +408,15 @@ def aten〇addmm〡shape(self: List[int], mat1: List[int], mat2: List[int], beta ErrorInvocation(TensorOfShape(2, 3, 4), TensorOfShape(2, 4)), # RHS is not rank 3. ]) def aten〇bmm〡shape(self: List[int], mat2: List[int]) -> List[int]: - assert len(self) == 3, "bmm only supports 3D tensors" - assert len(mat2) == 3, "bmm only supports 3D tensors" - assert self[0] == mat2[0], "mismatching batch dimension" - assert self[2] == mat2[1], "mismatching contracting dimension" - return [self[0], self[1], mat2[2]] + return upstream_shape_functions.bmm(self, mat2) def aten〇baddbmm〡shape(self: List[int], batch1: List[int], batch2: List[int], beta: float = 1, alpha: float = 1) -> List[int]: - assert len(batch1) == 3, "baddbmm only supports 3D tensors" - assert len(batch2) == 3, "baddbmm only supports 3D tensors" - assert batch1[0] == batch2[0], "mismatching batch dimension" - assert batch1[2] == batch2[1], "mismatching contracting dimension" - return [batch1[0], batch1[1], batch2[2]] + return upstream_shape_functions.bmm(batch1, batch2) def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]: return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse) +# TODO: upstream this def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]: assert len(repeats) >= len(self) ndim = len(repeats) @@ -801,12 +783,7 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in ErrorInvocation(TensorOfShape(2, 3), 2, dim=100), # `dim` out of bounds. ]) def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]: - assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}" - # All lists which represent tensor shapes are expected to be the result - # of a fresh invocation of `AtenSizeOp`, which allocates a new, unaliased - # list. So in-place mutations are ok. - self[dim] = k - return self, self + return upstream_shape_functions.topk(self, k, dim) def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) @@ -867,13 +844,7 @@ def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], inp return upstream_shape_functions.conv_backwards(grad_output, input, weight, None) def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: - # Torch's symbolic shape analysis is a bit looser about optional - # arguments than we are, so their batch_norm helper function works - # even though the `weight` is not `Optional`. - # Upstream is working to make this more consistent. - # For now, since this function is so trivial, just write it ourselves. - #return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) - return input + return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) @@ -918,24 +889,12 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets ErrorInvocation(TensorOfShape(2, 3), LongTensorOfShape(7), None, 1, -100), # Mismatched batch dimension. ]) def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int) -> Tuple[List[int], List[int]]: - # This is taken shamelessly from the meta function in LossNLL.cpp - self_dim = len(self) - target_dim = len(target) - assert 0 < self_dim <= 2 - assert target_dim <= 1 - no_batch_dim = self_dim == 1 and target_dim == 0 - assert no_batch_dim or (self[0] == target[0]) - n_classes = self[-1] - scalar_shape: List[int] = [] - assert weight is None or (len(weight) == 1 and weight[0] == n_classes) - if reduction == 0 and self_dim == 2: - return [self[0]], scalar_shape - else: - return scalar_shape, scalar_shape + return upstream_shape_functions.nll_loss_forward(self, target, weight, reduction) def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +# TODO: upstream this def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]: if reduction == 0: return upstream_shape_functions.unary(self) @@ -945,14 +904,7 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]: - reduction_shape: List[int] = [] - num_unreduced_dimensions = len(input) - len(normalized_shape) - assert num_unreduced_dimensions >= 0 - for i in range(num_unreduced_dimensions): - reduction_shape.append(input[i]) - for i in range(num_unreduced_dimensions, len(input)): - reduction_shape.append(1) - return input, reduction_shape, reduction_shape + return upstream_shape_functions.native_layer_norm(input, normalized_shape) @check_shape_function([ Invocation(TensorOfShape(2, 3), None, None, None, None, True, 1e-4, 1e-6), # Training basic case. @@ -962,9 +914,7 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low. ]) def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]: - if training: - return input, [input[1]], [input[1]] - return input, [0], [0] + return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training) # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. @@ -990,6 +940,7 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) +# TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" broadcasted_shape: List[int] = []