mirror of https://github.com/llvm/torch-mlir
Bump the shape lib to match the upstream functions currently in PyTorch (#1236)
Bumps the shape library: - Updates the function signature for aten.arange.start_step - upstream_shape_functions.mean_dim -> upstream_shape_functions.sum_mean_dimpull/1217/head
parent
11a5b5ac52
commit
85f383ce0b
|
@ -28,5 +28,5 @@ fi
|
||||||
|
|
||||||
PYTHONPATH="${pypath}" python \
|
PYTHONPATH="${pypath}" python \
|
||||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
|
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
|
||||||
--pytorch_op_extensions=${ext_module} \
|
--pytorch_op_extensions=${ext_module:-""} \
|
||||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
||||||
|
|
|
@ -3494,6 +3494,137 @@ module {
|
||||||
%6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
%6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
|
||||||
return %6 : !torch.tuple<list<int>, list<int>, list<int>>
|
return %6 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||||
}
|
}
|
||||||
|
func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
|
||||||
|
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
|
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||||
|
torch.prim.Loop %8, %true, init() {
|
||||||
|
^bb0(%arg9: !torch.int):
|
||||||
|
%9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||||
|
%10 = torch.prim.If %1 -> (!torch.int) {
|
||||||
|
%11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
torch.prim.If.yield %12 : !torch.int
|
||||||
|
} else {
|
||||||
|
torch.prim.If.yield %int1 : !torch.int
|
||||||
|
}
|
||||||
|
torch.prim.If %arg6 -> () {
|
||||||
|
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
torch.prim.If.yield
|
||||||
|
} else {
|
||||||
|
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
torch.prim.If.yield
|
||||||
|
}
|
||||||
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
|
return %3 : !torch.list<int>
|
||||||
|
}
|
||||||
|
func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.int, %arg7: !torch.optional<list<int>>) -> !torch.list<int> {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.aten.__is__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
|
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
||||||
|
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
} else {
|
||||||
|
%15 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
}
|
||||||
|
%2 = torch.aten.__is__ %arg4, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
|
%3 = torch.prim.If %2 -> (!torch.list<int>) {
|
||||||
|
%15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
} else {
|
||||||
|
%15 = torch.prim.unchecked_cast %arg4 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
}
|
||||||
|
%4 = torch.aten.__is__ %arg7, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
|
%5 = torch.prim.If %4 -> (!torch.list<int>) {
|
||||||
|
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
} else {
|
||||||
|
%15 = torch.prim.unchecked_cast %arg7 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %15 : !torch.list<int>
|
||||||
|
}
|
||||||
|
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
||||||
|
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
|
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%11 = torch.aten.append.t %9, %10 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
%12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
%14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||||
|
torch.prim.Loop %14, %true, init() {
|
||||||
|
^bb0(%arg8: !torch.int):
|
||||||
|
%15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
|
||||||
|
%16 = torch.prim.If %7 -> (!torch.int) {
|
||||||
|
%32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%33 = torch.aten.__getitem__.t %5, %32 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
torch.prim.If.yield %33 : !torch.int
|
||||||
|
} else {
|
||||||
|
torch.prim.If.yield %int1 : !torch.int
|
||||||
|
}
|
||||||
|
%17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%23 = torch.aten.__getitem__.t %1, %22 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%26 = torch.aten.__getitem__.t %3, %25 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
%27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
|
%31 = torch.aten.append.t %9, %30 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
|
return %9 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
|
func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
%str = torch.constant.str "AssertionError: "
|
%str = torch.constant.str "AssertionError: "
|
||||||
|
@ -4369,70 +4500,78 @@ module {
|
||||||
}
|
}
|
||||||
return %6 : !torch.list<int>
|
return %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @__torch__.torch.jit._shape_functions.mean_dim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
|
func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
|
||||||
%none = torch.constant.none
|
|
||||||
%str = torch.constant.str "AssertionError: "
|
%str = torch.constant.str "AssertionError: "
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
%none = torch.constant.none
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%1 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
torch.prim.Loop %1, %true, init() {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
|
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %4 : !torch.list<int>
|
||||||
|
} else {
|
||||||
|
%4 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
|
torch.prim.If.yield %4 : !torch.list<int>
|
||||||
|
}
|
||||||
|
%3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
|
torch.prim.Loop %3, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
%4 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int
|
||||||
%3 = torch.prim.Loop %2, %true, init(%false) {
|
%5 = torch.prim.Loop %4, %true, init(%false) {
|
||||||
^bb0(%arg5: !torch.int, %arg6: !torch.bool):
|
^bb0(%arg5: !torch.int, %arg6: !torch.bool):
|
||||||
%4 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
%6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
%7 = torch.prim.If %6 -> (!torch.int) {
|
%9 = torch.prim.If %8 -> (!torch.int) {
|
||||||
torch.prim.If.yield %int1 : !torch.int
|
torch.prim.If.yield %int1 : !torch.int
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield %5 : !torch.int
|
torch.prim.If.yield %7 : !torch.int
|
||||||
}
|
}
|
||||||
%8 = torch.aten.neg.int %7 : !torch.int -> !torch.int
|
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
|
||||||
%9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
|
%11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int
|
||||||
%10 = torch.aten.lt.int %4, %8 : !torch.int, !torch.int -> !torch.bool
|
%12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool
|
||||||
%11 = torch.prim.If %10 -> (!torch.bool) {
|
%13 = torch.prim.If %12 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%17 = torch.aten.gt.int %4, %9 : !torch.int, !torch.int -> !torch.bool
|
%19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %17 : !torch.bool
|
torch.prim.If.yield %19 : !torch.bool
|
||||||
}
|
}
|
||||||
%12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool
|
%14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool
|
||||||
torch.prim.If %12 -> () {
|
torch.prim.If %14 -> () {
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
} else {
|
} else {
|
||||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
}
|
}
|
||||||
%13 = torch.aten.lt.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
|
%15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
%14 = torch.prim.If %13 -> (!torch.int) {
|
%16 = torch.prim.If %15 -> (!torch.int) {
|
||||||
%17 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int
|
%19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int
|
||||||
torch.prim.If.yield %17 : !torch.int
|
torch.prim.If.yield %19 : !torch.int
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield %4 : !torch.int
|
torch.prim.If.yield %6 : !torch.int
|
||||||
}
|
}
|
||||||
%15 = torch.aten.eq.int %arg4, %14 : !torch.int, !torch.int -> !torch.bool
|
%17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool
|
||||||
%16 = torch.prim.If %15 -> (!torch.bool) {
|
%18 = torch.prim.If %17 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield %arg6 : !torch.bool
|
torch.prim.If.yield %arg6 : !torch.bool
|
||||||
}
|
}
|
||||||
torch.prim.Loop.condition %true, iter(%16 : !torch.bool)
|
torch.prim.Loop.condition %true, iter(%18 : !torch.bool)
|
||||||
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
|
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
|
||||||
torch.prim.If %3 -> () {
|
torch.prim.If %5 -> () {
|
||||||
torch.prim.If %arg2 -> () {
|
torch.prim.If %arg2 -> () {
|
||||||
%4 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%6 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
} else {
|
} else {
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
}
|
}
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
} else {
|
} else {
|
||||||
%4 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
|
%6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
%5 = torch.aten.append.t %0, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
}
|
}
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
@ -4442,10 +4581,10 @@ module {
|
||||||
func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {
|
func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
%none = torch.constant.none
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
%str = torch.constant.str "AssertionError: "
|
%str = torch.constant.str "AssertionError: "
|
||||||
%none = torch.constant.none
|
|
||||||
%0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
|
%0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
|
||||||
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
|
@ -5441,10 +5580,6 @@ module {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
|
||||||
return %0 : !torch.list<int>
|
|
||||||
}
|
|
||||||
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
@ -5524,6 +5659,10 @@ module {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
@ -5580,27 +5719,28 @@ module {
|
||||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %7 : !torch.bool
|
torch.prim.If.yield %8 : !torch.bool
|
||||||
}
|
}
|
||||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %5, %true, init() {
|
torch.prim.Loop %6, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %6 : !torch.list<int>
|
torch.prim.If.yield %7 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%3 = torch.derefine %none : !torch.none to !torch.any
|
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||||
return %4 : !torch.list<int>
|
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
@ -5610,27 +5750,28 @@ module {
|
||||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %7 : !torch.bool
|
torch.prim.If.yield %8 : !torch.bool
|
||||||
}
|
}
|
||||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %5, %true, init() {
|
torch.prim.Loop %6, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %6 : !torch.list<int>
|
torch.prim.If.yield %7 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%3 = torch.derefine %none : !torch.none to !torch.any
|
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||||
return %4 : !torch.list<int>
|
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
|
||||||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
@ -5644,27 +5785,28 @@ module {
|
||||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %7 : !torch.bool
|
torch.prim.If.yield %8 : !torch.bool
|
||||||
}
|
}
|
||||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %5, %true, init() {
|
torch.prim.Loop %6, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %6 : !torch.list<int>
|
torch.prim.If.yield %7 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%3 = torch.derefine %none : !torch.none to !torch.any
|
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%4 = torch.derefine %none : !torch.none to !torch.any
|
||||||
return %4 : !torch.list<int>
|
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
|
@ -5726,27 +5868,28 @@ module {
|
||||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %7 : !torch.bool
|
torch.prim.If.yield %8 : !torch.bool
|
||||||
}
|
}
|
||||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %5, %true, init() {
|
torch.prim.Loop %6, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %6 : !torch.list<int>
|
torch.prim.If.yield %7 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||||
return %4 : !torch.list<int>
|
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
@ -5756,27 +5899,28 @@ module {
|
||||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||||
torch.prim.If.yield %true : !torch.bool
|
torch.prim.If.yield %true : !torch.bool
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
|
||||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
torch.prim.If.yield %7 : !torch.bool
|
torch.prim.If.yield %8 : !torch.bool
|
||||||
}
|
}
|
||||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %5, %true, init() {
|
torch.prim.Loop %6, %true, init() {
|
||||||
^bb0(%arg4: !torch.int):
|
^bb0(%arg4: !torch.int):
|
||||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %6 : !torch.list<int>
|
torch.prim.If.yield %7 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||||
return %4 : !torch.list<int>
|
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
|
@ -6889,21 +7033,22 @@ module {
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
%0 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
%0 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||||
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
||||||
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
%5 = torch.prim.ListConstruct : () -> !torch.list<int>
|
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
torch.prim.Loop %4, %true, init() {
|
torch.prim.Loop %5, %true, init() {
|
||||||
^bb0(%arg5: !torch.int):
|
^bb0(%arg5: !torch.int):
|
||||||
%6 = torch.aten.append.t %5, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
|
%7 = torch.aten.append.t %6, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||||
torch.prim.Loop.condition %true, iter()
|
torch.prim.Loop.condition %true, iter()
|
||||||
} : (!torch.int, !torch.bool) -> ()
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
torch.prim.If.yield %5 : !torch.list<int>
|
torch.prim.If.yield %6 : !torch.list<int>
|
||||||
} else {
|
} else {
|
||||||
%4 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
|
%5 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
|
||||||
torch.prim.If.yield %4 : !torch.list<int>
|
torch.prim.If.yield %5 : !torch.list<int>
|
||||||
}
|
}
|
||||||
%2 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
%2 = torch.derefine %1 : !torch.list<int> to !torch.optional<list<int>>
|
||||||
%3 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %1, %arg3, %2) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||||
return %3 : !torch.list<int>
|
%4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||||
|
return %4 : !torch.list<int>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)mlir");
|
)mlir");
|
||||||
|
|
|
@ -526,12 +526,12 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
|
||||||
def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||||
if dim is None or len(dim)==0:
|
if dim is None or len(dim)==0:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
|
def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
|
||||||
if dim is None or len(dim)==0:
|
if dim is None or len(dim)==0:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
|
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
@ -539,7 +539,7 @@ def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
|
||||||
def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||||
if dim is None or len(dim)==0:
|
if dim is None or len(dim)==0:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||||
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
|
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
|
||||||
|
@ -576,12 +576,12 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
|
||||||
def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
if dim is None or len(dim)==0:
|
if dim is None or len(dim)==0:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
if dim is None or len(dim)==0:
|
if dim is None or len(dim)==0:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.permute(self, dims)
|
return upstream_shape_functions.permute(self, dims)
|
||||||
|
@ -813,7 +813,7 @@ def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
|
||||||
def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
def aten〇arange〇start_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
|
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
|
||||||
|
|
||||||
def aten〇arange〇start(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
def aten〇arange〇start(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
|
@ -1171,7 +1171,7 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen
|
||||||
def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = list(range(len(self)))
|
dim = list(range(len(self)))
|
||||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Shape library generator main().
|
# Shape library generator main().
|
||||||
|
|
Loading…
Reference in New Issue