mirror of https://github.com/llvm/torch-mlir
Use upstream shape functions when available (#1952)
There are several ops that have their shape function upstream and had not been updated in Torch-MLIR to use the upstream version. This commit updates those shape function. In addition, TODOs have been added for shape functions that should be upstream but are not.pull/1964/head
parent
158be370d1
commit
a7449785ec
|
@ -6348,56 +6348,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
|
||||
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %2 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %3 = func.call @__torch__._reduce_along_dim(%arg0, %2, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__._reduce_along_dim(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %5 = torch.prim.min.self_int %4 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %5, %true, init() {\n"
|
||||
" ^bb0(%arg3: !torch.int):\n"
|
||||
" %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
" torch.prim.If %arg2 -> () {\n"
|
||||
" %8 = torch.aten.append.t %2, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %8 = torch.aten.append.t %2, %6 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
|
||||
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %1 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
@ -6462,102 +6425,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bmm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: bmm only supports 3D tensors\"\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %9 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %13 : !torch.list<int>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.baddbmm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: baddbmm only supports 3D tensors\"\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %9 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %13 : !torch.list<int>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.embedding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.embedding(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.int, !torch.bool, !torch.bool) -> !torch.list<int>\n"
|
||||
|
@ -7124,23 +6997,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n"
|
||||
" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n"
|
||||
" %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %3 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
|
@ -7210,7 +7068,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
||||
|
@ -7320,98 +7179,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %17 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %10 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %16 = torch.aten.len.t %15 : !torch.list<int> -> !torch.int\n"
|
||||
" %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %18 = torch.prim.If %17 -> (!torch.bool) {\n"
|
||||
" %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %20 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %18 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %11 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %14 = torch.prim.If %13 -> (!torch.tuple<list<int>, list<int>>) {\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %17 = torch.prim.TupleConstruct %16, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %17 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.prim.TupleConstruct %9, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %15 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" return %14 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
|
@ -7430,54 +7199,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop %3, %true, init() {\n"
|
||||
" ^bb0(%arg5: !torch.int):\n"
|
||||
" %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.Loop %6, %true, init() {\n"
|
||||
" ^bb0(%arg5: !torch.int):\n"
|
||||
" %8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %7 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.prim.If %arg5 -> (!torch.tuple<list<int>, list<int>, list<int>>) {\n"
|
||||
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %5 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" } else {\n"
|
||||
" %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
|
||||
|
|
|
@ -14,6 +14,7 @@ import torch.jit._shape_functions as upstream_shape_functions
|
|||
from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function
|
||||
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar
|
||||
|
||||
# TODO: upstream this
|
||||
def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int):
|
||||
assert len(weight) == 2
|
||||
assert len(indices) == 1
|
||||
|
@ -343,17 +344,6 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased
|
|||
def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
|
||||
out: List[int] = []
|
||||
for i, self_dim in enumerate(self):
|
||||
if i == dim:
|
||||
if keepdim:
|
||||
out.append(1)
|
||||
else:
|
||||
out.append(self_dim)
|
||||
return out
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||
Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
|
||||
|
@ -364,15 +354,13 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
|||
ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
|
||||
])
|
||||
def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
|
||||
if dim is None:
|
||||
return []
|
||||
return _reduce_along_dim(self, dim, keepdim)
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
|
||||
return _reduce_along_dim(self, dim, keepdim)
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||
reduced_shape = _reduce_along_dim(self, dim, keepdim)
|
||||
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
return reduced_shape, reduced_shape
|
||||
|
||||
def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
||||
|
@ -393,6 +381,7 @@ def aten〇transpose〇int〡shape(self: List[int], dim0: int, dim1: int) -> Lis
|
|||
def aten〇t〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.transpose(self, 0, 1)
|
||||
|
||||
# TODO: upstream this
|
||||
def aten〇numpy_T〡shape(self: List[int]) -> List[int]:
|
||||
result_shape: List[int] = []
|
||||
for i in self:
|
||||
|
@ -419,22 +408,15 @@ def aten〇addmm〡shape(self: List[int], mat1: List[int], mat2: List[int], beta
|
|||
ErrorInvocation(TensorOfShape(2, 3, 4), TensorOfShape(2, 4)), # RHS is not rank 3.
|
||||
])
|
||||
def aten〇bmm〡shape(self: List[int], mat2: List[int]) -> List[int]:
|
||||
assert len(self) == 3, "bmm only supports 3D tensors"
|
||||
assert len(mat2) == 3, "bmm only supports 3D tensors"
|
||||
assert self[0] == mat2[0], "mismatching batch dimension"
|
||||
assert self[2] == mat2[1], "mismatching contracting dimension"
|
||||
return [self[0], self[1], mat2[2]]
|
||||
return upstream_shape_functions.bmm(self, mat2)
|
||||
|
||||
def aten〇baddbmm〡shape(self: List[int], batch1: List[int], batch2: List[int], beta: float = 1, alpha: float = 1) -> List[int]:
|
||||
assert len(batch1) == 3, "baddbmm only supports 3D tensors"
|
||||
assert len(batch2) == 3, "baddbmm only supports 3D tensors"
|
||||
assert batch1[0] == batch2[0], "mismatching batch dimension"
|
||||
assert batch1[2] == batch2[1], "mismatching contracting dimension"
|
||||
return [batch1[0], batch1[1], batch2[2]]
|
||||
return upstream_shape_functions.bmm(batch1, batch2)
|
||||
|
||||
def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
|
||||
|
||||
# TODO: upstream this
|
||||
def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
|
||||
assert len(repeats) >= len(self)
|
||||
ndim = len(repeats)
|
||||
|
@ -801,12 +783,7 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in
|
|||
ErrorInvocation(TensorOfShape(2, 3), 2, dim=100), # `dim` out of bounds.
|
||||
])
|
||||
def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]:
|
||||
assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
|
||||
# All lists which represent tensor shapes are expected to be the result
|
||||
# of a fresh invocation of `AtenSizeOp`, which allocates a new, unaliased
|
||||
# list. So in-place mutations are ok.
|
||||
self[dim] = k
|
||||
return self, self
|
||||
return upstream_shape_functions.topk(self, k, dim)
|
||||
|
||||
def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
@ -867,13 +844,7 @@ def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], inp
|
|||
return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
|
||||
|
||||
def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
|
||||
# Torch's symbolic shape analysis is a bit looser about optional
|
||||
# arguments than we are, so their batch_norm helper function works
|
||||
# even though the `weight` is not `Optional`.
|
||||
# Upstream is working to make this more consistent.
|
||||
# For now, since this function is so trivial, just write it ourselves.
|
||||
#return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
|
||||
return input
|
||||
return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
|
||||
|
||||
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||
|
@ -918,24 +889,12 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets
|
|||
ErrorInvocation(TensorOfShape(2, 3), LongTensorOfShape(7), None, 1, -100), # Mismatched batch dimension.
|
||||
])
|
||||
def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int) -> Tuple[List[int], List[int]]:
|
||||
# This is taken shamelessly from the meta function in LossNLL.cpp
|
||||
self_dim = len(self)
|
||||
target_dim = len(target)
|
||||
assert 0 < self_dim <= 2
|
||||
assert target_dim <= 1
|
||||
no_batch_dim = self_dim == 1 and target_dim == 0
|
||||
assert no_batch_dim or (self[0] == target[0])
|
||||
n_classes = self[-1]
|
||||
scalar_shape: List[int] = []
|
||||
assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
|
||||
if reduction == 0 and self_dim == 2:
|
||||
return [self[0]], scalar_shape
|
||||
else:
|
||||
return scalar_shape, scalar_shape
|
||||
return upstream_shape_functions.nll_loss_forward(self, target, weight, reduction)
|
||||
|
||||
def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
# TODO: upstream this
|
||||
def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
|
||||
if reduction == 0:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
@ -945,14 +904,7 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
|
|||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||
])
|
||||
def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
|
||||
reduction_shape: List[int] = []
|
||||
num_unreduced_dimensions = len(input) - len(normalized_shape)
|
||||
assert num_unreduced_dimensions >= 0
|
||||
for i in range(num_unreduced_dimensions):
|
||||
reduction_shape.append(input[i])
|
||||
for i in range(num_unreduced_dimensions, len(input)):
|
||||
reduction_shape.append(1)
|
||||
return input, reduction_shape, reduction_shape
|
||||
return upstream_shape_functions.native_layer_norm(input, normalized_shape)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3), None, None, None, None, True, 1e-4, 1e-6), # Training basic case.
|
||||
|
@ -962,9 +914,7 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in
|
|||
ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low.
|
||||
])
|
||||
def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]:
|
||||
if training:
|
||||
return input, [input[1]], [input[1]]
|
||||
return input, [0], [0]
|
||||
return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training)
|
||||
|
||||
# TODO: This should be upstreamed.
|
||||
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
||||
|
@ -990,6 +940,7 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float
|
|||
def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
|
||||
return pad_shape_fn(self, pad)
|
||||
|
||||
# TODO: upstream this
|
||||
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
||||
assert len(indices) <= len(self), "More indices than dimensions to index"
|
||||
broadcasted_shape: List[int] = []
|
||||
|
|
Loading…
Reference in New Issue