mirror of https://github.com/llvm/torch-mlir
Clean up shape functions that use `sum_mean_dim` (#1217)
I recently fixed the handling of the `dim` argument in
`sum_mean_dim` (59fccab857
). Therefore,
the checks that the `dim` input is `None` or `[]` are no longer needed.
pull/1228/head
parent
7d4a0d0e2b
commit
f07f7d20f9
|
@ -3494,137 +3494,6 @@ module {
|
|||
%6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
||||
return %6 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||
}
|
||||
func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
|
||||
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.Loop %8, %true, init() {
|
||||
^bb0(%arg9: !torch.int):
|
||||
%9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
%10 = torch.prim.If %1 -> (!torch.int) {
|
||||
%11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %12 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int1 : !torch.int
|
||||
}
|
||||
torch.prim.If %arg6 -> () {
|
||||
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
|
||||
%14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
|
||||
%19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
|
||||
%23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
|
||||
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
|
||||
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
|
||||
%20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
|
||||
%21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
|
||||
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.If.yield
|
||||
}
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
return %3 : !torch.list<int>
|
||||
}
|
||||
func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.int, %arg7: !torch.optional<list<int>>) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.__is__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
||||
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
} else {
|
||||
%15 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
}
|
||||
%2 = torch.aten.__is__ %arg4, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%3 = torch.prim.If %2 -> (!torch.list<int>) {
|
||||
%15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
} else {
|
||||
%15 = torch.prim.unchecked_cast %arg4 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
}
|
||||
%4 = torch.aten.__is__ %arg7, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%5 = torch.prim.If %4 -> (!torch.list<int>) {
|
||||
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
} else {
|
||||
%15 = torch.prim.unchecked_cast %arg7 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %15 : !torch.list<int>
|
||||
}
|
||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%11 = torch.aten.append.t %9, %10 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.Loop %14, %true, init() {
|
||||
^bb0(%arg8: !torch.int):
|
||||
%15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||
%16 = torch.prim.If %7 -> (!torch.int) {
|
||||
%32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%33 = torch.aten.__getitem__.t %5, %32 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %33 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int1 : !torch.int
|
||||
}
|
||||
%17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
|
||||
%20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%23 = torch.aten.__getitem__.t %1, %22 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int
|
||||
%25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%26 = torch.aten.__getitem__.t %3, %25 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
%28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int
|
||||
%29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int
|
||||
%30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%31 = torch.aten.append.t %9, %30 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
return %9 : !torch.list<int>
|
||||
}
|
||||
func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
|
||||
%none = torch.constant.none
|
||||
%str = torch.constant.str "AssertionError: "
|
||||
|
@ -4502,76 +4371,90 @@ module {
|
|||
}
|
||||
func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
|
||||
%str = torch.constant.str "AssertionError: "
|
||||
%int0 = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%1 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.If.yield %4 : !torch.list<int>
|
||||
%2 = torch.prim.If %1 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%4 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %4 : !torch.list<int>
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %7 : !torch.bool
|
||||
}
|
||||
%3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
torch.prim.Loop %3, %true, init() {
|
||||
%3 = torch.prim.If %2 -> (!torch.list<int>) {
|
||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %5, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
} else {
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %5 : !torch.list<int>
|
||||
}
|
||||
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
torch.prim.Loop %4, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%4 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int
|
||||
%5 = torch.prim.Loop %4, %true, init(%false) {
|
||||
%5 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.prim.Loop %5, %true, init(%false) {
|
||||
^bb0(%arg5: !torch.int, %arg6: !torch.bool):
|
||||
%6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%9 = torch.prim.If %8 -> (!torch.int) {
|
||||
%7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%10 = torch.prim.If %9 -> (!torch.int) {
|
||||
torch.prim.If.yield %int1 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %7 : !torch.int
|
||||
torch.prim.If.yield %8 : !torch.int
|
||||
}
|
||||
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
|
||||
%11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool
|
||||
%13 = torch.prim.If %12 -> (!torch.bool) {
|
||||
%11 = torch.aten.neg.int %10 : !torch.int -> !torch.int
|
||||
%12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> !torch.bool
|
||||
%14 = torch.prim.If %13 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %19 : !torch.bool
|
||||
%20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %20 : !torch.bool
|
||||
}
|
||||
%14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool
|
||||
torch.prim.If %14 -> () {
|
||||
%15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool
|
||||
torch.prim.If %15 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%16 = torch.prim.If %15 -> (!torch.int) {
|
||||
%19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %19 : !torch.int
|
||||
%16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%17 = torch.prim.If %16 -> (!torch.int) {
|
||||
%20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %20 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %6 : !torch.int
|
||||
torch.prim.If.yield %7 : !torch.int
|
||||
}
|
||||
%17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool
|
||||
%18 = torch.prim.If %17 -> (!torch.bool) {
|
||||
%18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool
|
||||
%19 = torch.prim.If %18 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
torch.prim.If.yield %arg6 : !torch.bool
|
||||
}
|
||||
torch.prim.Loop.condition %true, iter(%18 : !torch.bool)
|
||||
torch.prim.Loop.condition %true, iter(%19 : !torch.bool)
|
||||
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
|
||||
torch.prim.If %5 -> () {
|
||||
torch.prim.If %6 -> () {
|
||||
torch.prim.If %arg2 -> () {
|
||||
%6 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%7 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.If.yield
|
||||
}
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
%6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%8 = torch.aten.append.t %0, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.If.yield
|
||||
}
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
|
@ -4582,8 +4465,8 @@ module {
|
|||
%false = torch.constant.bool false
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%str = torch.constant.str "AssertionError: "
|
||||
%0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
|
@ -5712,101 +5595,26 @@ module {
|
|||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %8 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %7 : !torch.list<int>
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %5 : !torch.list<int>
|
||||
%0 = torch.derefine %none : !torch.none to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %8 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %7 : !torch.list<int>
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %5 : !torch.list<int>
|
||||
%0 = torch.derefine %none : !torch.none to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
|
||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %8 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %7 : !torch.list<int>
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %5 : !torch.list<int>
|
||||
%0 = torch.derefine %none : !torch.none to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
|
||||
%none = torch.constant.none
|
||||
|
@ -5861,66 +5669,14 @@ module {
|
|||
return %1 : !torch.tuple<list<int>, list<int>>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %8 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %7 : !torch.list<int>
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %5 : !torch.list<int>
|
||||
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %8 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %6, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %7 : !torch.list<int>
|
||||
} else {
|
||||
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %5 : !torch.list<int>
|
||||
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
|
@ -7029,26 +6785,9 @@ module {
|
|||
return %none : !torch.none
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %5, %true, init() {
|
||||
^bb0(%arg5: !torch.int):
|
||||
%7 = torch.aten.append.t %6, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
} else {
|
||||
%5 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %5 : !torch.list<int>
|
||||
}
|
||||
%2 = torch.derefine %1 : !torch.list<int> to !torch.optional<list<int>>
|
||||
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||
%4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %4 : !torch.list<int>
|
||||
%0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
}
|
||||
)mlir");
|
||||
|
|
|
@ -524,21 +524,15 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
|
|||
return []
|
||||
|
||||
def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
|
||||
return []
|
||||
|
||||
def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
|
@ -574,13 +568,9 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
|
|||
return reduced_shape, reduced_shape
|
||||
|
||||
def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
||||
|
@ -1169,8 +1159,6 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen
|
|||
return [hacky_get_unknown_dimension_size()]
|
||||
|
||||
def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
if dim is None:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
# ==============================================================================
|
||||
|
|
Loading…
Reference in New Issue