Clean up shape functions that use `sum_mean_dim` (#1217)

I recently fixed the handling of the `dim` argument in
`sum_mean_dim` (59fccab857). Therefore,
the checks that the `dim` input is `None` or `[]` are no longer needed.
pull/1228/head
Ramiro Leal-Cavazos 2022-08-18 08:23:43 -07:00 committed by GitHub
parent 7d4a0d0e2b
commit f07f7d20f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 341 deletions

View File

@ -3494,137 +3494,6 @@ module {
%6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
return %6 : !torch.tuple<list<int>, list<int>, list<int>>
}
func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg9: !torch.int):
%9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%10 = torch.prim.If %1 -> (!torch.int) {
%11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
torch.prim.If %arg6 -> () {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
%23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.int, %arg7: !torch.optional<list<int>>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.aten.__is__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%2 = torch.aten.__is__ %arg4, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg4 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%4 = torch.aten.__is__ %arg7, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg7 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%11 = torch.aten.append.t %9, %10 : !torch.list<int>, !torch.int -> !torch.list<int>
%12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
%14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %14, %true, init() {
^bb0(%arg8: !torch.int):
%15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%16 = torch.prim.If %7 -> (!torch.int) {
%32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%33 = torch.aten.__getitem__.t %5, %32 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.__getitem__.t %1, %22 : !torch.list<int>, !torch.int -> !torch.int
%24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%26 = torch.aten.__getitem__.t %3, %25 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int
%28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int
%29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.append.t %9, %30 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %9 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
@ -4502,76 +4371,90 @@ module {
}
func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
%2 = torch.prim.If %1 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%4 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %3, %true, init() {
%3 = torch.prim.If %2 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %4, %true, init() {
^bb0(%arg4: !torch.int):
%4 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int
%5 = torch.prim.Loop %4, %true, init(%false) {
%5 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int
%6 = torch.prim.Loop %5, %true, init(%false) {
^bb0(%arg5: !torch.int, %arg6: !torch.bool):
%6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %7 : !torch.int
torch.prim.If.yield %8 : !torch.int
}
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%11 = torch.aten.neg.int %10 : !torch.int -> !torch.int
%12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
%20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
}
%14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.int) {
%19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
%16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %6 : !torch.int
torch.prim.If.yield %7 : !torch.int
}
%17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.bool) {
%18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %arg6 : !torch.bool
}
torch.prim.Loop.condition %true, iter(%18 : !torch.bool)
torch.prim.Loop.condition %true, iter(%19 : !torch.bool)
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
torch.prim.If %5 -> () {
torch.prim.If %6 -> () {
torch.prim.If %arg2 -> () {
%6 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
%7 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
%7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%8 = torch.aten.append.t %0, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
@ -4582,8 +4465,8 @@ module {
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%str = torch.constant.str "AssertionError: "
%0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
@ -5712,101 +5595,26 @@ module {
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %7 : !torch.list<int>
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %7 : !torch.list<int>
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %7 : !torch.list<int>
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
@ -5861,66 +5669,14 @@ module {
return %1 : !torch.tuple<list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %7 : !torch.list<int>
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %7 : !torch.list<int>
} else {
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
@ -7029,26 +6785,9 @@ module {
return %none : !torch.none
}
func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%0 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg5: !torch.int):
%7 = torch.aten.append.t %6, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%2 = torch.derefine %1 : !torch.list<int> to !torch.optional<list<int>>
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
}
)mlir");

View File

@ -524,21 +524,15 @@ def atenvar(self: List[int], unbiased: bool = True) -> List[int]:
return []
def atenvardim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenvarcorrection(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
return []
def atenstddim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
@ -574,13 +568,9 @@ def atenmaxdim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
return reduced_shape, reduced_shape
def atenmeandim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
def atensumdim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
def atenpermute(self: List[int], dims: List[int]) -> List[int]:
@ -1169,8 +1159,6 @@ def atenbincount(self: List[int], weights: Optional[List[int]] = None, minlen
return [hacky_get_unknown_dimension_size()]
def atenlinalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None:
dim = list(range(len(self)))
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
# ==============================================================================