From 855d267c57f8cba9b5395f7144d1b2b7a6d78351 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Sun, 2 Oct 2022 14:05:53 -0500 Subject: [PATCH] build: update shape library after PyTorch version update (#1449) The auto-update of the PyTorch version broke the Torch-MLIR build because it did not update the shape library. Going forward, we should add the shape library update to the PyTorch version update action. --- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 1001 ++--------------- 1 file changed, 71 insertions(+), 930 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 95b8c703d..9419bf60e 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -23,877 +23,6 @@ StringRef mlir::torch::Torch::getShapeLibrary() { #endif // clang-format off return "module {\n" -" func.func @__torch__.torch._decomp.decompositions.nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %str = torch.constant.str \"Expected a single element grad_output tensor, but got: {}\"\n" -" %str_0 = torch.constant.str \"Expected a tensor of dimension 1 and tensor.size[0] == {} but got: dimension {} and tensor.size[0] == {}\"\n" -" %str_1 = torch.constant.str \"AssertionError: weight tensor should be defined either for all or no classes\"\n" -" %int-1 = torch.constant.int -1\n" -" %str_2 = torch.constant.str \"{} ({} elements)\"\n" -" %str_3 = torch.constant.str \"expected total_weight to be a single element tensor, got: \"\n" -" %str_4 = torch.constant.str \"AssertionError: \"\n" -" %str_5 = torch.constant.str \"size mismatch (got input: {}, target: {})\"\n" -" %true = torch.constant.bool true\n" -" %str_6 = torch.constant.str \"AssertionError: 0D or 1D target tensor expected, multi-target not supported\"\n" -" %none = torch.constant.none\n" -" %str_7 = torch.constant.str \"AssertionError: input tensor should be 1D or 2D\"\n" -" %false = torch.constant.bool false\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.Uninitialized : !torch.optional\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.le.int %int0, %1 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.le.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %5 = torch.aten.le.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %9 = torch.prim.If %8 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %35 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %36 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %37 = torch.aten.eq.int %35, %36 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %37 : !torch.bool\n" -" }\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %36 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" -" %37 = torch.aten.format(%str_5, %35, %36) : !torch.str, !torch.list, !torch.list -> !torch.str\n" -" %38 = torch.aten.add.str %str_4, %37 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %38, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %11 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" -" %36 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %37 = torch.aten.format(%str_2, %35, %36) : !torch.str, !torch.list, !torch.int -> !torch.str\n" -" %38 = torch.prim.TupleConstruct %str_3, %37 : !torch.str, !torch.str -> !torch.tuple\n" -" %39 = torch.aten.str %38 : !torch.tuple -> !torch.str\n" -" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %12 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %35 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %36 = torch.aten.numel %35 : !torch.tensor -> !torch.int\n" -" %37 = torch.aten.size.int %arg1, %int-1 : !torch.tensor, !torch.int -> !torch.int\n" -" %38 = torch.aten.eq.int %36, %37 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %38 : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.optional) {\n" -" torch.prim.If.yield %arg3 : !torch.optional\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.optional\n" -" }\n" -" %15 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %16 = torch.prim.If %15 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %16 -> () {\n" -" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %37 = torch.prim.If %36 -> (!torch.bool) {\n" -" %38 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %39 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %40 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %37 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %38 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %39 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %40 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %41 = torch.aten.format(%str_0, %38, %39, %40) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" -" %42 = torch.aten.add.str %str_4, %41 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %42, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.le.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %37 = torch.prim.If %36 -> (!torch.bool) {\n" -" %38 = torch.aten.numel %arg0 : !torch.tensor -> !torch.int\n" -" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %39 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %37 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %38 = torch.aten.size %arg0 : !torch.tensor -> !torch.list\n" -" %39 = torch.aten.format(%str, %38) : !torch.str, !torch.list -> !torch.str\n" -" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" }\n" -" %17 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %18 = torch.aten.lt.int %17, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %19 = torch.prim.If %18 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %20 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %21 = torch.prim.If %20 -> (!torch.tensor) {\n" -" %35 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %22 = torch.aten.unsqueeze %arg2, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %23 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %24 = torch.operator \"aten.scatter.value\"(%23, %19, %22, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %25 = torch.aten.dim %24 : !torch.tensor -> !torch.int\n" -" %26 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" -" %27 = torch.aten.gt.int %25, %26 : !torch.int, !torch.int -> !torch.bool\n" -" %28 = torch.prim.If %27 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.gt.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %35 = torch.aten.unsqueeze %21, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %21 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.__isnot__ %14, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" -" %35 = torch.prim.unchecked_cast %14 : !torch.optional -> !torch.tensor\n" -" %36 = torch.prim.ListConstruct : () -> !torch.list\n" -" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %37, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %42 = torch.aten.append.t %36, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %38 = torch.aten.size.int %35, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %39 = torch.aten._set_item.t %36, %19, %38 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %40 = torch.aten.reshape %35, %36 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %41 = torch.aten.mul.Tensor %29, %40 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %41 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %29 : !torch.tensor\n" -" }\n" -" %32 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %33 = torch.prim.If %32 -> (!torch.tensor) {\n" -" %35 = torch.aten.ne.Scalar %22, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %36 = torch.aten.where.ScalarOther %35, %31, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %36 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %31 : !torch.tensor\n" -" }\n" -" %34 = torch.aten.mul.Tensor %24, %33 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %34 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %true = torch.constant.bool true\n" -" %false = torch.constant.bool false\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %none = torch.constant.none\n" -" %int2 = torch.constant.int 2\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.lt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %3 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.tensor) {\n" -" %18 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %18 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %5 = torch.aten.unsqueeze %arg2, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %6 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %7 = torch.operator \"aten.scatter.value\"(%6, %2, %5, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %8 = torch.aten.dim %7 : !torch.tensor -> !torch.int\n" -" %9 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" -" %10 = torch.aten.gt.int %8, %9 : !torch.int, !torch.int -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" %18 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" -" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %19 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %12 = torch.prim.If %11 -> (!torch.tensor) {\n" -" %18 = torch.aten.unsqueeze %4, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %18 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" %13 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %14 = torch.prim.If %13 -> (!torch.tensor) {\n" -" %18 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %19 = torch.prim.ListConstruct : () -> !torch.list\n" -" %20 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %20, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %25 = torch.aten.append.t %19, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %21 = torch.aten.size.int %18, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %22 = torch.aten._set_item.t %19, %2, %21 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %23 = torch.aten.reshape %18, %19 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %24 = torch.aten.mul.Tensor %12, %23 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %24 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %12 : !torch.tensor\n" -" }\n" -" %15 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %16 = torch.prim.If %15 -> (!torch.tensor) {\n" -" %18 = torch.aten.ne.Scalar %5, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %19 = torch.aten.where.ScalarOther %18, %14, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %19 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %14 : !torch.tensor\n" -" }\n" -" %17 = torch.aten.mul.Tensor %7, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %17 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.nll_loss2d_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %true = torch.constant.bool true\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %str = torch.constant.str \"expected total_weight to be a single element tensor, got: {} ( {}, elements)\"\n" -" %str_0 = torch.constant.str \"size mismatch (got input: {}, target: {}\"\n" -" %false = torch.constant.bool false\n" -" %str_1 = torch.constant.str \"only batches of spatial targets supported (3D tensors) but got targets of dimension: {}\"\n" -" %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: \"\n" -" %str_3 = torch.constant.str \"only batches of spatial inputs supported (4D tensors), but got input of dimension: {}\"\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %30 = torch.aten.format(%str_3, %29) : !torch.str, !torch.int -> !torch.str\n" -" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %30 = torch.aten.format(%str_1, %29) : !torch.str, !torch.int -> !torch.str\n" -" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %5 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %29 = torch.aten.size.int %arg1, %int2 : !torch.tensor, !torch.int -> !torch.int\n" -" %30 = torch.aten.size.int %arg2, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %31 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %29 = torch.aten.size.int %arg1, %int3 : !torch.tensor, !torch.int -> !torch.int\n" -" %30 = torch.aten.size.int %arg2, %int2 : !torch.tensor, !torch.int -> !torch.int\n" -" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %31 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %30 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" -" %31 = torch.aten.format(%str_0, %29, %30) : !torch.str, !torch.list, !torch.list -> !torch.str\n" -" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %9 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" -" %30 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %31 = torch.aten.format(%str, %29, %30) : !torch.str, !torch.list, !torch.int -> !torch.str\n" -" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %11 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %12 = torch.aten.lt.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %14 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %15 = torch.prim.If %14 -> (!torch.tensor) {\n" -" %29 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %29 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %16 = torch.aten.unsqueeze %arg2, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %17 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %18 = torch.operator \"aten.scatter.value\"(%17, %13, %16, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %19 = torch.aten.dim %18 : !torch.tensor -> !torch.int\n" -" %20 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" -" %21 = torch.aten.gt.int %19, %20 : !torch.int, !torch.int -> !torch.bool\n" -" %22 = torch.prim.If %21 -> (!torch.bool) {\n" -" %29 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" -" %30 = torch.aten.gt.int %29, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %30 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %23 = torch.prim.If %22 -> (!torch.tensor) {\n" -" %29 = torch.aten.unsqueeze %15, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %29 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %15 : !torch.tensor\n" -" }\n" -" %24 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %25 = torch.prim.If %24 -> (!torch.tensor) {\n" -" %29 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %30 = torch.prim.ListConstruct : () -> !torch.list\n" -" %31 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %31, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %36 = torch.aten.append.t %30, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %32 = torch.aten.size.int %29, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %33 = torch.aten._set_item.t %30, %13, %32 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %34 = torch.aten.reshape %29, %30 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %23, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %23 : !torch.tensor\n" -" }\n" -" %26 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %27 = torch.prim.If %26 -> (!torch.tensor) {\n" -" %29 = torch.aten.ne.Scalar %16, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %30 = torch.aten.where.ScalarOther %29, %25, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %30 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %25 : !torch.tensor\n" -" }\n" -" %28 = torch.aten.mul.Tensor %18, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %28 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._log_softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" -" %false = torch.constant.bool false\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %true = torch.constant.bool true\n" -" %0 = torch.aten.exp %arg1 : !torch.tensor -> !torch.tensor\n" -" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" -" %2 = torch.aten.sum.dim_IntList %arg0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %3 = torch.aten.mul.Tensor %0, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.sub.Tensor %arg0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %6 = torch.aten.ne.int %5, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %8 = torch.aten.to.dtype %4, %arg3, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %8 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" return %7 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._cast_grad_to_input_dtype(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int) -> !torch.tensor {\n" -" %none = torch.constant.none\n" -" %false = torch.constant.bool false\n" -" %0 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.ne.int %0, %arg2 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.tensor) {\n" -" %3 = torch.aten.to.dtype %arg1, %arg2, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %3 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg1 : !torch.tensor\n" -" }\n" -" return %2 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" -" %false = torch.constant.bool false\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %true = torch.constant.bool true\n" -" %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" -" %2 = torch.aten.sum.dim_IntList %0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %3 = torch.aten.mul.Tensor %arg1, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.sub.Tensor %0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %6 = torch.aten.ne.int %5, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %8 = torch.aten.to.dtype %4, %arg3, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %8 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" return %7 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.log_sigmoid_forward(%arg0: !torch.tensor) -> !torch.tuple {\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.new_zeros %arg0, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %2 = torch.aten.minimum %1, %arg0 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %3 = torch.aten.abs %arg0 : !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.neg %3 : !torch.tensor -> !torch.tensor\n" -" %5 = torch.aten.exp %4 : !torch.tensor -> !torch.tensor\n" -" %6 = torch.operator \"prim.is_cuda\"(%arg0) : (!torch.tensor) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %11 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %12 = torch.aten.new_zeros %arg0, %11, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %12 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %5 : !torch.tensor\n" -" }\n" -" %8 = torch.aten.log1p %5 : !torch.tensor -> !torch.tensor\n" -" %9 = torch.aten.sub.Tensor %2, %8, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %10 = torch.prim.TupleConstruct %9, %7 : !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %10 : !torch.tuple\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_layer_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.tensor, %arg4: !torch.tensor, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.list) -> !torch.tuple, optional, optional> {\n" -" %false = torch.constant.bool false\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %int2 = torch.constant.int 2\n" -" %0 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" -" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" -" %4 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %5 = torch.aten.slice.t %0, %none, %3, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" -" %6 = torch.prim.ListConstruct : () -> !torch.list\n" -" %7 = torch.aten.__range_length %3, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %7, %true, init() {\n" -" ^bb0(%arg8: !torch.int):\n" -" %17 = torch.aten.__derive_index %arg8, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten.append.t %6, %17 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %8 = torch.prim.ListConstruct : () -> !torch.list\n" -" %9 = torch.aten.__range_length %int0, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %9, %true, init() {\n" -" ^bb0(%arg8: !torch.int):\n" -" %17 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten.append.t %8, %17 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %10 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" -" %11 = torch.prim.Loop %10, %true, init(%int1) {\n" -" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" -" %17 = torch.aten.__getitem__.t %4, %arg8 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %12 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" -" %13 = torch.prim.Loop %12, %true, init(%int1) {\n" -" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" -" %17 = torch.aten.__getitem__.t %5, %arg8 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %14 = torch.aten.le.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %15 = torch.prim.If %14 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %17 = torch.aten.le.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %17 : !torch.bool\n" -" }\n" -" %16 = torch.prim.If %15 -> (!torch.tuple, optional, optional>) {\n" -" %17 = torch.aten.new_zeros %arg1, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %18 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %19 = torch.aten.new_zeros %arg1, %18, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %20 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %21 = torch.aten.new_zeros %arg1, %20, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %22 = torch.prim.TupleConstruct %17, %19, %21 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" %23 = torch.derefine %22 : !torch.tuple to !torch.tuple, optional, optional>\n" -" torch.prim.If.yield %23 : !torch.tuple, optional, optional>\n" -" } else {\n" -" %17 = torch.aten.mean.dim %arg1, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %18 = torch.aten.var.dim %arg1, %6, %false, %true : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %19 = torch.aten.reciprocal %arg4 : !torch.tensor -> !torch.tensor\n" -" %20 = torch.aten.pow.Tensor_Scalar %19, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %21 = torch.aten.sub.Tensor %20, %18, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %22 = torch.aten.detach %21 : !torch.tensor -> !torch.tensor\n" -" %23 = torch.aten.add.Tensor %18, %22, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %24 = torch.aten.sqrt %23 : !torch.tensor -> !torch.tensor\n" -" %25 = torch.aten.reciprocal %24 : !torch.tensor -> !torch.tensor\n" -" %26 = torch.aten.sub.Tensor %arg1, %17, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %27 = torch.aten.mul.Tensor %26, %25 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %28 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %46 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" -" %47 = torch.aten.mul.Tensor %arg0, %46 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.mul.Scalar %29, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %31 = torch.aten.sum.dim_IntList %29, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %32 = torch.aten.mul.Tensor %29, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %33 = torch.aten.sum.dim_IntList %32, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %34 = torch.aten.mul.Tensor %27, %33 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %35 = torch.aten.sub.Tensor %30, %31, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %36 = torch.aten.sub.Tensor %35, %34, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %37 = torch.aten.__getitem__.t %arg7, %int0 : !torch.list, !torch.int -> !torch.bool\n" -" %38 = torch.prim.If %37 -> (!torch.tensor) {\n" -" %46 = torch.aten.div.Scalar %25, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %47 = torch.aten.mul.Tensor %46, %36 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %46 : !torch.tensor\n" -" }\n" -" %39 = torch.aten.__getitem__.t %arg7, %int1 : !torch.list, !torch.int -> !torch.bool\n" -" %40 = torch.prim.If %39 -> (!torch.bool) {\n" -" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %46 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %41 = torch.prim.If %40 -> (!torch.tensor) {\n" -" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" -" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" -" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %50 = torch.aten.sum.dim_IntList %49, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %50 : !torch.tensor\n" -" } else {\n" -" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %48 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" -" %48 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" -" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %48 = torch.prim.ListConstruct : () -> !torch.list\n" -" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" }\n" -" %42 = torch.aten.__getitem__.t %arg7, %int2 : !torch.list, !torch.int -> !torch.bool\n" -" %43 = torch.prim.If %42 -> (!torch.bool) {\n" -" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %46 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %44 = torch.prim.If %43 -> (!torch.tensor) {\n" -" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" -" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" -" %49 = torch.aten.sum.dim_IntList %arg0, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %49 = torch.aten.clone %arg0, %none : !torch.tensor, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %48 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" -" %48 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" -" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %48 = torch.prim.ListConstruct : () -> !torch.list\n" -" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" }\n" -" %45 = torch.prim.TupleConstruct %38, %41, %44 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple, optional, optional>\n" -" torch.prim.If.yield %45 : !torch.tuple, optional, optional>\n" -" }\n" -" return %16 : !torch.tuple, optional, optional>\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.recompute_mean_var(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.tuple {\n" -" %false = torch.constant.bool false\n" -" %none = torch.constant.none\n" -" %int1 = torch.constant.int 1\n" -" %int2 = torch.constant.int 2\n" -" %0 = torch.aten.mean.dim %arg0, %arg2, %arg3, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %1 = torch.aten.var.dim %arg0, %arg2, %false, %arg3 : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %2 = torch.aten.reciprocal %arg1 : !torch.tensor -> !torch.tensor\n" -" %3 = torch.aten.mul.Scalar %2, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %4 = torch.aten.pow.Tensor_Scalar %3, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.aten.sub.Tensor %4, %1, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %6 = torch.aten.detach %5 : !torch.tensor -> !torch.tensor\n" -" %7 = torch.aten.add.Tensor %1, %6, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %8 = torch.aten.sqrt %7 : !torch.tensor -> !torch.tensor\n" -" %9 = torch.aten.reciprocal %8 : !torch.tensor -> !torch.tensor\n" -" %10 = torch.aten.mul.Scalar %9, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %11 = torch.prim.TupleConstruct %0, %10 : !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %11 : !torch.tuple\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.bool, %arg8: !torch.float, %arg9: !torch.list) -> !torch.tuple, optional> {\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %true = torch.constant.bool true\n" -" %str_0 = torch.constant.str \"AssertionError: when train=True, save_mean and save_invstd are required\"\n" -" %false = torch.constant.bool false\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: rank of the input must be at least 2\"\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %float1.000000e00 = torch.constant.float 1.000000e+00\n" -" %0 = torch.prim.Uninitialized : !torch.tensor\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.ge.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %4 = torch.prim.Loop %3, %true, init(%int1) {\n" -" ^bb0(%arg10: !torch.int, %arg11: !torch.int):\n" -" %34 = torch.aten.size.int %arg1, %arg10 : !torch.tensor, !torch.int -> !torch.int\n" -" %35 = torch.aten.mul.int %arg11, %34 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%35 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %5 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %6 = torch.operator \"aten.div.int\"(%4, %5) : (!torch.int, !torch.int) -> !torch.float\n" -" %7:2 = torch.prim.If %arg7 -> (!torch.tensor, !torch.tensor) {\n" -" %34 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.bool) {\n" -" %52 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %52 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %35 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %36 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %38 = torch.prim.ListConstruct : () -> !torch.list\n" -" %39 = torch.aten.__range_length %int2, %37, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %39, %true, init() {\n" -" ^bb0(%arg10: !torch.int):\n" -" %52 = torch.aten.__derive_index %arg10, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %53 = torch.aten.append.t %38, %52 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %40 = torch.aten.add.t %36, %38 : !torch.list, !torch.list -> !torch.list\n" -" %41 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %42 = torch.prim.If %41 -> (!torch.tensor) {\n" -" %52 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" -" torch.prim.If.yield %52 : !torch.tensor\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.tensor\n" -" }\n" -" %43 = torch.aten.mean.dim %arg1, %40, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %44 = torch.aten.var.dim %arg1, %40, %false, %false : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %45 = torch.aten.reciprocal %42 : !torch.tensor -> !torch.tensor\n" -" %46 = torch.aten.pow.Tensor_Scalar %45, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %47 = torch.aten.sub.Tensor %46, %44, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %48 = torch.aten.detach %47 : !torch.tensor -> !torch.tensor\n" -" %49 = torch.aten.add.Tensor %44, %48, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %50 = torch.aten.sqrt %49 : !torch.tensor -> !torch.tensor\n" -" %51 = torch.aten.reciprocal %50 : !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %43, %51 : !torch.tensor, !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.bool) {\n" -" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %40 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %40 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %36:2 = torch.prim.If %35 -> (!torch.tensor, !torch.tensor) {\n" -" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %40 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.tensor\n" -" torch.prim.If.yield %40, %39 : !torch.tensor, !torch.tensor\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0, %0 : !torch.tensor, !torch.tensor\n" -" }\n" -" %37 = torch.aten.add.Scalar %36#0, %arg8, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor\n" -" %38 = torch.aten.rsqrt %37 : !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %36#1, %38 : !torch.tensor, !torch.tensor\n" -" }\n" -" %8 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" -" %9 = torch.operator \"aten.mul.left_t\"(%8, %1) : (!torch.list, !torch.int) -> !torch.list\n" -" %10 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %11 = torch.aten._set_item.t %9, %int1, %10 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %12 = torch.prim.ListConstruct : () -> !torch.list\n" -" torch.prim.Loop %1, %true, init() {\n" -" ^bb0(%arg10: !torch.int):\n" -" %34 = torch.aten.ne.int %arg10, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %34 -> () {\n" -" %35 = torch.aten.append.t %12, %arg10 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %13 = torch.aten.reshape %7#0, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %14 = torch.aten.div.float %float1.000000e00, %6 : !torch.float, !torch.float -> !torch.float\n" -" %15 = torch.aten.sum.dim_IntList %arg0, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %16 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %17 = torch.aten.mul.Tensor %arg0, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %18 = torch.aten.sum.dim_IntList %17, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %19 = torch.aten.mul.Scalar %15, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" -" %20 = torch.aten.reshape %19, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %21 = torch.aten.mul.Scalar %18, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" -" %22 = torch.aten.mul.Tensor %7#1, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %23 = torch.aten.mul.Tensor %21, %22 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %24 = torch.aten.reshape %23, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %25 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %26 = torch.prim.If %25 -> (!torch.tensor) {\n" -" %34 = torch.aten.reshape %7#1, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %35 = torch.aten.mul.Scalar %34, %float1.000000e00 : !torch.tensor, !torch.float -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %7#1, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %36 = torch.aten.reshape %35, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" torch.prim.If.yield %36 : !torch.tensor\n" -" }\n" -" %27 = torch.prim.If %arg7 -> (!torch.tensor) {\n" -" %34 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %34, %24 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %36 = torch.aten.sub.Tensor %arg0, %35, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %37 = torch.aten.sub.Tensor %36, %20, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %38 = torch.aten.mul.Tensor %37, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %38 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.mul.Tensor %arg0, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" }\n" -" %28 = torch.aten.__getitem__.t %arg9, %int1 : !torch.list, !torch.int -> !torch.bool\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %34 = torch.aten.mul.Tensor %18, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.tensor) {\n" -" %36 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" -" %37 = torch.aten.zeros_like %36, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %37 : !torch.tensor\n" -" } else {\n" -" %36 = torch.prim.ListConstruct : () -> !torch.list\n" -" %37 = torch.aten.zeros %36, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %37 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.__getitem__.t %arg9, %int2 : !torch.list, !torch.int -> !torch.bool\n" -" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" -" torch.prim.If.yield %15 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.zeros_like %15, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" }\n" -" %32 = torch.prim.TupleConstruct %27, %29, %31 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" %33 = torch.derefine %32 : !torch.tuple to !torch.tuple, optional>\n" -" return %33 : !torch.tuple, optional>\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.prod(%arg0: !torch.list) -> !torch.int {\n" -" %true = torch.constant.bool true\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" -" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" -" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.cudnn_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.float, %arg8: !torch.tensor) -> !torch.tuple {\n" -" %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list\n" -" %result0, %result1, %result2 = torch.aten.native_batch_norm_backward %arg1, %arg0, %arg2, %arg3, %arg4, %arg5, %arg6, %true, %arg7, %0 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.optional, !torch.optional, !torch.optional, !torch.optional, !torch.bool, !torch.float, !torch.list -> !torch.tensor, !torch.tensor, !torch.tensor\n" -" %1 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %1 : !torch.tuple\n" -" }\n" " func.func @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" @@ -5629,83 +4758,95 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %12 : !torch.list\n" " }\n" -" func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.optional> {\n" -" %str = torch.constant.str \"AssertionError: Either output_size or scale_factors must be presented\"\n" -" %str_0 = torch.constant.str \"AssertionError: \"\n" -" %str_1 = torch.constant.str \"AssertionError: Must specify exactly one of output_size and scale_factors\"\n" +" func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of output_size and scale_factors\"\n" +" %str_1 = torch.constant.str \"AssertionError: Either output_size or scale_factors must be presented\"\n" +" %false = torch.constant.bool false\n" " %none = torch.constant.none\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" " %int2 = torch.constant.int 2\n" " %int3 = torch.constant.int 3\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %2 = torch.aten.append.t %0, %1 : !torch.list, !torch.int -> !torch.list\n" -" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.append.t %0, %3 : !torch.list, !torch.int -> !torch.list\n" -" %5 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.optional>) {\n" -" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" -" %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" +" %0 = torch.prim.Uninitialized : !torch.optional>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.append.t %1, %2 : !torch.list, !torch.int -> !torch.list\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.append.t %1, %4 : !torch.list, !torch.int -> !torch.list\n" +" %6 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.optional>) {\n" +" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.optional>) {\n" +" torch.prim.If.yield %arg2 : !torch.optional>\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.optional>\n" +" }\n" +" %14 = torch.aten.len.t %11 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %9 = torch.aten.len.t %7 : !torch.list -> !torch.int\n" -" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" +" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.append.t %1, %16 : !torch.list, !torch.int -> !torch.list\n" +" %18 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %13 : !torch.optional>\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.optional>\n" +" }\n" +" %10 = torch.aten.__isnot__ %9, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" %11 = torch.prim.unchecked_cast %9 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %11 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.append.t %0, %11 : !torch.list, !torch.int -> !torch.list\n" -" %13 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list\n" -" %15 = torch.derefine %0 : !torch.list to !torch.optional>\n" -" torch.prim.If.yield %15 : !torch.optional>\n" -" } else {\n" -" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.optional>) {\n" -" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" -" %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %11 = torch.aten.len.t %9 : !torch.list -> !torch.int\n" -" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %12 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %13 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %14 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %15 = torch.operator \"aten.mul.int_float\"(%13, %14) : (!torch.int, !torch.float) -> !torch.float\n" -" %16 = torch.aten.Int.float %15 : !torch.float -> !torch.int\n" -" %17 = torch.aten.append.t %0, %16 : !torch.list, !torch.int -> !torch.list\n" -" %18 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %19 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %20 = torch.operator \"aten.mul.int_float\"(%18, %19) : (!torch.int, !torch.float) -> !torch.float\n" -" %21 = torch.aten.Int.float %20 : !torch.float -> !torch.int\n" -" %22 = torch.aten.append.t %0, %21 : !torch.list, !torch.int -> !torch.list\n" -" %23 = torch.derefine %0 : !torch.list to !torch.optional>\n" -" torch.prim.If.yield %23 : !torch.optional>\n" +" %13 = torch.aten.len.t %11 : !torch.list -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" %9 = torch.derefine %none : !torch.none to !torch.optional>\n" -" torch.prim.If.yield %9 : !torch.optional>\n" +" torch.prim.If.yield\n" " }\n" -" torch.prim.If.yield %8 : !torch.optional>\n" +" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float\n" +" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" +" %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" +" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float\n" +" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" +" %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" " }\n" -" return %6 : !torch.optional>\n" +" return %1 : !torch.list\n" " }\n" " func.func @__torch__.torch.jit._shape_functions.argmax(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" " %true = torch.constant.bool true\n"