Bump the shape lib to match the upstream functions currently in PyTorch (#1236)

Bumps the shape library:
 - Updates the function signature for aten.arange.start_step
 - upstream_shape_functions.mean_dim -> upstream_shape_functions.sum_mean_dim
pull/1217/head
Quinn Dawkins 2022-08-17 00:11:04 -04:00 committed by GitHub
parent 11a5b5ac52
commit 85f383ce0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 269 additions and 124 deletions

View File

@ -28,5 +28,5 @@ fi
PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
--pytorch_op_extensions=${ext_module} \
--pytorch_op_extensions=${ext_module:-""} \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"

View File

@ -3494,6 +3494,137 @@ module {
%6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
return %6 : !torch.tuple<list<int>, list<int>, list<int>>
}
func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg9: !torch.int):
%9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%10 = torch.prim.If %1 -> (!torch.int) {
%11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
torch.prim.If %arg6 -> () {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
%23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.int, %arg7: !torch.optional<list<int>>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.aten.__is__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%2 = torch.aten.__is__ %arg4, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg4 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%4 = torch.aten.__is__ %arg7, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.list<int>) {
%15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
} else {
%15 = torch.prim.unchecked_cast %arg7 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %15 : !torch.list<int>
}
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
%8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%11 = torch.aten.append.t %9, %10 : !torch.list<int>, !torch.int -> !torch.list<int>
%12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
%14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %14, %true, init() {
^bb0(%arg8: !torch.int):
%15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%16 = torch.prim.If %7 -> (!torch.int) {
%32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%33 = torch.aten.__getitem__.t %5, %32 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.__getitem__.t %1, %22 : !torch.list<int>, !torch.int -> !torch.int
%24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
%26 = torch.aten.__getitem__.t %3, %25 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int
%28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int
%29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.append.t %9, %30 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %9 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
@ -4369,70 +4500,78 @@ module {
}
return %6 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.mean_dim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
%none = torch.constant.none
func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %1, %true, init() {
%1 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
} else {
%4 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
}
%3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %3, %true, init() {
^bb0(%arg4: !torch.int):
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%3 = torch.prim.Loop %2, %true, init(%false) {
%4 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int
%5 = torch.prim.Loop %4, %true, init(%false) {
^bb0(%arg5: !torch.int, %arg6: !torch.bool):
%4 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.int) {
%6 = torch.aten.__getitem__.t %2, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%8 = torch.aten.le.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %5 : !torch.int
torch.prim.If.yield %7 : !torch.int
}
%8 = torch.aten.neg.int %7 : !torch.int -> !torch.int
%9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.lt.int %4, %8 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.bool) {
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.lt.int %6, %10 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%17 = torch.aten.gt.int %4, %9 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
%19 = torch.aten.gt.int %6, %11 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
%12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool
torch.prim.If %12 -> () {
%14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool
torch.prim.If %14 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%13 = torch.aten.lt.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%17 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %17 : !torch.int
%15 = torch.aten.lt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.int) {
%19 = torch.aten.add.int %6, %9 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.If.yield %4 : !torch.int
torch.prim.If.yield %6 : !torch.int
}
%15 = torch.aten.eq.int %arg4, %14 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%17 = torch.aten.eq.int %arg4, %16 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %arg6 : !torch.bool
}
torch.prim.Loop.condition %true, iter(%16 : !torch.bool)
torch.prim.Loop.condition %true, iter(%18 : !torch.bool)
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
torch.prim.If %3 -> () {
torch.prim.If %5 -> () {
torch.prim.If %arg2 -> () {
%4 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%4 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.append.t %0, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
@ -4442,10 +4581,10 @@ module {
func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%str = torch.constant.str "AssertionError: "
%none = torch.constant.none
%0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
@ -5441,10 +5580,6 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
@ -5524,6 +5659,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
@ -5580,27 +5719,28 @@ module {
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
torch.prim.If.yield %7 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
@ -5610,27 +5750,28 @@ module {
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
torch.prim.If.yield %7 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
@ -5644,27 +5785,28 @@ module {
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
torch.prim.If.yield %7 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %none : !torch.none to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
@ -5726,27 +5868,28 @@ module {
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
torch.prim.If.yield %7 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
@ -5756,27 +5899,28 @@ module {
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%7 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %8 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
%6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.append.t %7, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
torch.prim.If.yield %7 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
%6 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
}
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
%3 = torch.derefine %2 : !torch.list<int> to !torch.optional<list<int>>
%4 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%5 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %3, %arg2, %4) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %5 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
@ -6889,21 +7033,22 @@ module {
%none = torch.constant.none
%0 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%5 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %4, %true, init() {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg5: !torch.int):
%6 = torch.aten.append.t %5, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
%7 = torch.aten.append.t %6, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %5 : !torch.list<int>
torch.prim.If.yield %6 : !torch.list<int>
} else {
%4 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
%5 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%2 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%3 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %1, %arg3, %2) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %3 : !torch.list<int>
%2 = torch.derefine %1 : !torch.list<int> to !torch.optional<list<int>>
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
}
)mlir");

View File

@ -526,12 +526,12 @@ def atenvar(self: List[int], unbiased: bool = True) -> List[int]:
def atenvardim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenvarcorrection(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
return []
@ -539,7 +539,7 @@ def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
def atenstddim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
@ -576,12 +576,12 @@ def atenmaxdim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
def atenmeandim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
def atensumdim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
def atenpermute(self: List[int], dims: List[int]) -> List[int]:
return upstream_shape_functions.permute(self, dims)
@ -813,7 +813,7 @@ def atenbernoulli(self: List[int], generator: Any = None) -> List[int]:
def atenrand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self
def atenarangestart_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
def atenarangestart_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
def atenarangestart(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
@ -1171,7 +1171,7 @@ def atenbincount(self: List[int], weights: Optional[List[int]] = None, minlen
def atenlinalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
# ==============================================================================
# Shape library generator main().