Use upstream shape functions when available (#1952)

Several ops have shape functions available upstream but were not yet
using the upstream versions in Torch-MLIR. This commit updates those
shape functions to call the upstream implementations. In addition,
TODOs have been added for shape functions that should be upstreamed
but are not yet.
Ramiro Leal-Cavazos 2023-03-24 09:13:43 -07:00 committed by GitHub
parent 158be370d1
commit a7449785ec
2 changed files with 38 additions and 361 deletions
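
The change follows one pattern throughout: a hand-written shape transfer function in the Python shape library (the second file in this diff) is replaced by a one-line call into torch.jit._shape_functions, which the file already imports as upstream_shape_functions. A minimal before/after sketch using aten.bmm; the function names bmm_shape_local and bmm_shape_upstream are hypothetical, for illustration only:

from typing import List

import torch.jit._shape_functions as upstream_shape_functions

# Before: a local reimplementation of the shape transfer function.
def bmm_shape_local(self: List[int], mat2: List[int]) -> List[int]:
    assert len(self) == 3, "bmm only supports 3D tensors"
    assert len(mat2) == 3, "bmm only supports 3D tensors"
    assert self[0] == mat2[0], "mismatching batch dimension"
    assert self[2] == mat2[1], "mismatching contracting dimension"
    return [self[0], self[1], mat2[2]]

# After: defer to the upstream implementation.
def bmm_shape_upstream(self: List[int], mat2: List[int]) -> List[int]:
    return upstream_shape_functions.bmm(self, mat2)

# Both agree: [2, 3, 4] x [2, 4, 5] -> [2, 3, 5].
assert bmm_shape_local([2, 3, 4], [2, 4, 5]) == bmm_shape_upstream([2, 3, 4], [2, 4, 5]) == [2, 3, 5]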


@@ -6348,56 +6348,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__._reduce_along_dim(%arg0, %2, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._reduce_along_dim(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = torch.prim.min.self_int %4 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %5, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If %arg2 -> () {\n"
" %8 = torch.aten.append.t %2, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" %8 = torch.aten.append.t %2, %6 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
@@ -6462,102 +6425,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bmm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n"
" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: bmm only supports 3D tensors\"\n"
" %int3 = torch.constant.int 3\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %13 : !torch.list<int>\n"
" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.baddbmm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n"
" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: baddbmm only supports 3D tensors\"\n"
" %int3 = torch.constant.int 3\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %13 : !torch.list<int>\n"
" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.embedding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.embedding(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.int, !torch.bool, !torch.bool) -> !torch.list<int>\n"
@@ -7124,23 +6997,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n"
" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n"
" %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>>\n"
" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
@@ -7210,7 +7068,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
@@ -7320,98 +7179,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %int-1 = torch.constant.int -1\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %17 : !torch.bool\n"
" }\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %10 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %15 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %16 = torch.aten.len.t %15 : !torch.list<int> -> !torch.int\n"
" %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %18 = torch.prim.If %17 -> (!torch.bool) {\n"
" %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %20 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %18 : !torch.bool\n"
" }\n"
" torch.prim.If %11 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %14 = torch.prim.If %13 -> (!torch.tuple<list<int>, list<int>>) {\n"
" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list<int>\n"
" %17 = torch.prim.TupleConstruct %16, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %17 : !torch.tuple<list<int>, list<int>>\n"
" } else {\n"
" %15 = torch.prim.TupleConstruct %9, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %15 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" return %14 : !torch.tuple<list<int>, list<int>>\n"
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
@@ -7430,54 +7199,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n"
" %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg5: !torch.int):\n"
" %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %6, %true, init() {\n"
" ^bb0(%arg5: !torch.int):\n"
" %8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %7 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.If %arg5 -> (!torch.tuple<list<int>, list<int>, list<int>>) {\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>\n"
" %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" torch.prim.If.yield %5 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" } else {\n"
" %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" torch.prim.If.yield %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"


@@ -14,6 +14,7 @@ import torch.jit._shape_functions as upstream_shape_functions
from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar
# TODO: upstream this
def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int):
    assert len(weight) == 2
    assert len(indices) == 1
@@ -343,17 +344,6 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased
def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
    dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
    out: List[int] = []
    for i, self_dim in enumerate(self):
        if i == dim:
            if keepdim:
                out.append(1)
        else:
            out.append(self_dim)
    return out
@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
    Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
@@ -364,15 +354,13 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
    ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
])
def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
    if dim is None:
        return []
    return _reduce_along_dim(self, dim, keepdim)
    return upstream_shape_functions.argmax(self, dim, keepdim)
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
    return _reduce_along_dim(self, dim, keepdim)
    return upstream_shape_functions.argmax(self, dim, keepdim)
def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
    reduced_shape = _reduce_along_dim(self, dim, keepdim)
    reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
    return reduced_shape, reduced_shape
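
If the upstream argmax helper has the same reduction semantics as the removed _reduce_along_dim, the following sanity checks should hold (a sketch, not part of the commit):

import torch.jit._shape_functions as upstream_shape_functions

# Reducing a [2, 3, 4] shape along dim 1:
assert upstream_shape_functions.argmax([2, 3, 4], dim=1, keepdim=False) == [2, 4]
assert upstream_shape_functions.argmax([2, 3, 4], dim=1, keepdim=True) == [2, 1, 4]
# With dim=None, argmax reduces to a rank-0 (scalar) shape:
assert upstream_shape_functions.argmax([2, 3, 4]) == []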
def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
@@ -393,6 +381,7 @@ def aten〇transpose〇int〡shape(self: List[int], dim0: int, dim1: int) -> Lis
def aten〇t〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.transpose(self, 0, 1)
# TODO: upstream this
def aten〇numpy_T〡shape(self: List[int]) -> List[int]:
    result_shape: List[int] = []
    for i in self:
@@ -419,22 +408,15 @@ def aten〇addmm〡shape(self: List[int], mat1: List[int], mat2: List[int], beta
    ErrorInvocation(TensorOfShape(2, 3, 4), TensorOfShape(2, 4)), # RHS is not rank 3.
])
def aten〇bmm〡shape(self: List[int], mat2: List[int]) -> List[int]:
    assert len(self) == 3, "bmm only supports 3D tensors"
    assert len(mat2) == 3, "bmm only supports 3D tensors"
    assert self[0] == mat2[0], "mismatching batch dimension"
    assert self[2] == mat2[1], "mismatching contracting dimension"
    return [self[0], self[1], mat2[2]]
    return upstream_shape_functions.bmm(self, mat2)
def aten〇baddbmm〡shape(self: List[int], batch1: List[int], batch2: List[int], beta: float = 1, alpha: float = 1) -> List[int]:
    assert len(batch1) == 3, "baddbmm only supports 3D tensors"
    assert len(batch2) == 3, "baddbmm only supports 3D tensors"
    assert batch1[0] == batch2[0], "mismatching batch dimension"
    assert batch1[2] == batch2[1], "mismatching contracting dimension"
    return [batch1[0], batch1[1], batch2[2]]
    return upstream_shape_functions.bmm(batch1, batch2)
def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
    return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
# TODO: upstream this
def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
    assert len(repeats) >= len(self)
    ndim = len(repeats)
@@ -801,12 +783,7 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in
    ErrorInvocation(TensorOfShape(2, 3), 2, dim=100), # `dim` out of bounds.
])
def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]:
    assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
    # All lists which represent tensor shapes are expected to be the result
    # of a fresh invocation of `AtenSizeOp`, which allocates a new, unaliased
    # list. So in-place mutations are ok.
    self[dim] = k
    return self, self
    return upstream_shape_functions.topk(self, k, dim)
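
Assuming upstream topk keeps the same contract as the removed code (check that k fits the dimension, then return the adjusted shape once for values and once for indices), a quick check looks like:

import torch.jit._shape_functions as upstream_shape_functions

values_shape, indices_shape = upstream_shape_functions.topk([2, 3, 4], k=2, dim=1)
assert values_shape == [2, 2, 4] and indices_shape == [2, 2, 4]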
def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
    return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
@@ -867,13 +844,7 @@ def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], inp
    return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
    # Torch's symbolic shape analysis is a bit looser about optional
    # arguments than we are, so their batch_norm helper function works
    # even though the `weight` is not `Optional`.
    # Upstream is working to make this more consistent.
    # For now, since this function is so trivial, just write it ourselves.
    #return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
    return input
    return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
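
Since batch_norm is shape-preserving, the upstream helper should return the input shape unchanged. A minimal check, assuming the upstream helper now accepts None for the optional tensor arguments (the concern raised in the removed comment):

import torch.jit._shape_functions as upstream_shape_functions

out = upstream_shape_functions.batch_norm([2, 3, 5, 5], None, None, None, None, True, 0.1, 1e-5, False)
assert out == [2, 3, 5, 5]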
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
    return upstream_shape_functions.slice(self, dim, start, end, step)
@@ -918,24 +889,12 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets
    ErrorInvocation(TensorOfShape(2, 3), LongTensorOfShape(7), None, 1, -100), # Mismatched batch dimension.
])
def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int) -> Tuple[List[int], List[int]]:
    # This is taken shamelessly from the meta function in LossNLL.cpp
    self_dim = len(self)
    target_dim = len(target)
    assert 0 < self_dim <= 2
    assert target_dim <= 1
    no_batch_dim = self_dim == 1 and target_dim == 0
    assert no_batch_dim or (self[0] == target[0])
    n_classes = self[-1]
    scalar_shape: List[int] = []
    assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
    if reduction == 0 and self_dim == 2:
        return [self[0]], scalar_shape
    else:
        return scalar_shape, scalar_shape
    return upstream_shape_functions.nll_loss_forward(self, target, weight, reduction)
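
The upstream nll_loss_forward mirrors the removed logic: per-element losses survive only for a 2D input with reduction == 0; every other case yields scalar shapes. A sketch of the expected behavior, assuming the usual 0 = none, 1 = mean, 2 = sum reduction encoding:

import torch.jit._shape_functions as upstream_shape_functions

# 2D input, reduction == 0 (none): per-batch losses plus a scalar total_weight.
assert upstream_shape_functions.nll_loss_forward([2, 5], [2], None, 0) == ([2], [])
# Any real reduction: both outputs are scalar shapes.
assert upstream_shape_functions.nll_loss_forward([2, 5], [2], None, 1) == ([], [])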
def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
# TODO: upstream this
def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
    if reduction == 0:
        return upstream_shape_functions.unary(self)
@@ -945,14 +904,7 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
    Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
])
def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
    reduction_shape: List[int] = []
    num_unreduced_dimensions = len(input) - len(normalized_shape)
    assert num_unreduced_dimensions >= 0
    for i in range(num_unreduced_dimensions):
        reduction_shape.append(input[i])
    for i in range(num_unreduced_dimensions, len(input)):
        reduction_shape.append(1)
    return input, reduction_shape, reduction_shape
    return upstream_shape_functions.native_layer_norm(input, normalized_shape)
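
Assuming upstream native_layer_norm matches the removed computation (the input shape is passed through; mean and rstd keep the unreduced leading dims and set the normalized trailing dims to 1), the basic-case invocation above would give:

import torch.jit._shape_functions as upstream_shape_functions

out, mean, rstd = upstream_shape_functions.native_layer_norm([2, 5, 2, 2, 3], [2, 2, 3])
assert out == [2, 5, 2, 2, 3]
assert mean == rstd == [2, 5, 1, 1, 1]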
@check_shape_function([
    Invocation(TensorOfShape(2, 3), None, None, None, None, True, 1e-4, 1e-6), # Training basic case.
@@ -962,9 +914,7 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in
    ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low.
])
def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]:
    if training:
        return input, [input[1]], [input[1]]
    return input, [0], [0]
    return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training)
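
Assuming upstream native_batch_norm preserves the removed branches (in training mode the mean and var shapes are [C] with C taken from dim 1; in inference mode the removed code reported them as [0]), the invocations above imply:

import torch.jit._shape_functions as upstream_shape_functions

out, mean, var = upstream_shape_functions.native_batch_norm([2, 3], None, None, None, None, True)
assert (out, mean, var) == ([2, 3], [3], [3])
out, mean, var = upstream_shape_functions.native_batch_norm([2, 3], None, None, None, None, False)
assert (out, mean, var) == ([2, 3], [0], [0])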
# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
@@ -990,6 +940,7 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float
def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
    return pad_shape_fn(self, pad)
# TODO: upstream this
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
    assert len(indices) <= len(self), "More indices than dimensions to index"
    broadcasted_shape: List[int] = []