// RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s

// CHECK-LABEL:   func @refine_shape_calculate_result$basic(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME:        %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK:           %[[INT2:.*]] = torch.constant.int 2
// CHECK:           %[[RESULT:.*]] = torch.shape.calculate {
// CHECK:             %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor to !torch.vtensor<[2,?],unk>
// CHECK:             torch.shape.calculate.yield %[[REFINED]] : !torch.vtensor<[2,?],unk>
// CHECK:           } shapes {
// CHECK:             %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:             torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK:           } : !torch.vtensor<[2,?],unk>
// CHECK:           %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK:           return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
  %int2 = torch.constant.int 2
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @refine_shape_calculate_result$clobber_one_element(
// CHECK:           %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,2],unk> to !torch.vtensor
// CHECK:           return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.prim.If %arg2 -> () {
      // Clobber element 0 of the list, so we can only know that the result
      // is [?,2] instead of [2,2].
      %2 = torch.aten._set_item.t %1, %int0, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
      torch.prim.If.yield
    } else {
      torch.prim.If.yield
    }
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @refine_shape_calculate_result$clobber_all_elements(
// CHECK:           %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK:           return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.prim.If %arg2 -> () {
      // Set an element at an unknown index of the list. This clobbers our
      // knowledge of the whole contents of the list, so we can only know
      // that the result is [?,?] instead of [2,2].
      %2 = torch.aten._set_item.t %1, %arg1, %int0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
      torch.prim.If.yield
    } else {
      torch.prim.If.yield
    }
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// Make sure that information previously present in the IR is not lost.
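// Here the shape function computes [2,?], which is met with the existing type
// !torch.vtensor<[?,3],f32>: the result is refined to !torch.vtensor<[2,3],f32>
// inside the calculation, and a cast back to !torch.vtensor<[?,3],f32>
// preserves the original type for users.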
// CHECK-LABEL:   func @refine_shape_calculate_result$meet_with_existing_information(
// CHECK:           %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,3],f32>
// CHECK:           return %[[RESULT_ERASED]] : !torch.vtensor<[?,3],f32>
func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor<[?,3],f32>
  } shapes {
    %1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor<[?,3],f32>
  return %0 : !torch.vtensor<[?,3],f32>
}

// Don't insert static info casts if they are not needed.

// CHECK-LABEL:   func @refine_shape_calculate_result$user_allows_type_refinement(
// CHECK-NOT:       torch.tensor_static_info_cast
func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor {
  %int2 = torch.constant.int 2
  %0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
  %1 = torch.shape.calculate {
    torch.shape.calculate.yield %0 : !torch.vtensor
  } shapes {
    %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.shape.calculate.yield.shapes %2 : !torch.list<int>
  } : !torch.vtensor
  %2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
  return %2 : !torch.vtensor
}

// CHECK-LABEL:   func @fully_unroll_prim_loop$unroll(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME:        %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK:           %[[INT1:.*]] = torch.constant.int 1
// CHECK:           %[[INT2:.*]] = torch.constant.int 2
// CHECK:           %[[INT0:.*]] = torch.constant.int 0
// CHECK:           %[[RESULT:.*]] = torch.shape.calculate {
// CHECK:             torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor
// CHECK:           } shapes {
// CHECK:             torch.prim.Print(%[[INT0]], %[[INT0]]) : !torch.int, !torch.int
// CHECK:             torch.prim.Print(%[[INT1]], %[[INT0]]) : !torch.int, !torch.int
// CHECK:             torch.prim.Print(%[[INT2]], %[[INT0]]) : !torch.int, !torch.int
// CHECK:             torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int>
// CHECK:           } : !torch.vtensor
// CHECK:           return %[[RESULT]] : !torch.vtensor
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
  %true = torch.constant.bool true
  %int0 = torch.constant.int 0
  %int3 = torch.constant.int 3
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    torch.prim.Loop %int3, %true, init(%int0) {
    ^bb0(%arg2: !torch.int, %arg3: !torch.int):
      torch.prim.Print(%arg2, %arg3) : !torch.int, !torch.int
      torch.prim.Loop.condition %true, iter(%arg3: !torch.int)
    } : (!torch.int, !torch.bool, !torch.int) -> (!torch.int)
    torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @fully_unroll_prim_loop$no_unroll(
// CHECK:           torch.prim.Loop
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
  %true = torch.constant.bool true
  %int3 = torch.constant.int 3
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    // The trip count %arg2 is not statically known, so the loop cannot be
    // unrolled.
    torch.prim.Loop %arg2, %true, init() {
    ^bb0(%arg3: !torch.int):
      torch.prim.Print(%arg2) : !torch.int
      torch.prim.Loop.condition %true, iter()
    } : (!torch.int, !torch.bool) -> ()
    torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @abstractly_interpret_list_ops$basic(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME:        %[[ARG1:.*]]: !torch.int,
// CHECK-SAME:        %[[ARG2:.*]]: !torch.int) -> !torch.vtensor {
// CHECK:             %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:             torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    %2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
    %3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// Test the different supported mutation ops.

// CHECK-LABEL:   func @abstractly_interpret_list_ops$mutation_ops(
// CHECK:           %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK:           torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten._set_item.t %1, %int1, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
    %3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.aten.insert.t %1, %int3, %arg3 : !torch.list<int>, !torch.int, !torch.int
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// Test interspersed mutation and evaluation ops.
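// The aten.len.t evaluations observe the length of the list between the
// appends, so the interpretation has to track the list contents at each
// program point; the resulting list is [0, 1, 2].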
// CHECK-LABEL:   func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK:           %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK:           torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torch.vtensor) -> !torch.vtensor {
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    %2 = torch.aten.len.t %1 : !torch.list<int> -> !torch.int
    %3 = torch.aten.append.t %1, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
    %4 = torch.aten.len.t %1 : !torch.list<int> -> !torch.int
    %5 = torch.aten.append.t %1, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
    %6 = torch.aten.len.t %1 : !torch.list<int> -> !torch.int
    %7 = torch.aten.append.t %1, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(
// CHECK:           torch.aten.append.t
// CHECK:           torch.aten.append.t
func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    %2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
    // The value of the alias %2 is printed, but we don't handle that yet.
    torch.prim.Print(%2) : !torch.list<int>
    %3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @abstractly_interpret_list_ops$readonly_op_in_child_region(
// CHECK-SAME:        %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME:        %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK:           %[[INT3:.*]] = torch.constant.int 3
// CHECK:             %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK:             torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
  %true = torch.constant.bool true
  %int3 = torch.constant.int 3
  %int0 = torch.constant.int 0
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    // This readonly op in a loop doesn't block us from abstractly
    // interpreting the whole block.
    torch.prim.Loop %arg1, %true, init() {
    ^bb0(%arg3: !torch.int):
      %2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.int
      torch.prim.Print(%2) : !torch.int
      torch.prim.Loop.condition %true, iter()
    } : (!torch.int, !torch.bool) -> ()
    %2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// The mutation in the child region prevents us from abstractly interpreting.
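// Unlike the readonly case above, the append inside the loop body executes a
// data-dependent number of times, so the final contents of the list cannot be
// determined statically and the list ops are left as-is.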
// CHECK-LABEL:   func @abstractly_interpret_list_ops$mutation_in_child_region(
// CHECK:           torch.aten.append.t
func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
  %true = torch.constant.bool true
  %int3 = torch.constant.int 3
  %int0 = torch.constant.int 0
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    torch.prim.Loop %arg1, %true, init() {
    ^bb0(%arg3: !torch.int):
      %2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.int
      torch.prim.Print(%2) : !torch.int
      // This mutation prevents us from abstractly interpreting.
      %3 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
      torch.prim.Loop.condition %true, iter()
    } : (!torch.int, !torch.bool) -> ()
    %2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// CHECK-LABEL:   func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME:        %[[ARG1:.*]]: !torch.list<int>,
// CHECK-SAME:        %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor {
// CHECK:           %[[INT3:.*]] = torch.constant.int 3
// CHECK:           %[[VAL_4:.*]] = torch.shape.calculate {
// CHECK:             %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor to !torch.vtensor<[3,3],unk>
// CHECK:             torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[3,3],unk>
// CHECK:           } shapes {
// Notice this torch.prim.ListConstruct...
// CHECK:             %[[VAL_6:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK:             %[[VAL_7:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK:               torch.prim.If.yield %[[VAL_6]] : !torch.list<int>
// CHECK:             } else {
// CHECK:               torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK:             }
// ... and this one don't have the same object identity, but should!
// CHECK:             %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:             %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK:               torch.prim.If.yield %[[VAL_8]] : !torch.list<int>
// CHECK:             } else {
// CHECK:               torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK:             }
// CHECK:             %[[VAL_10:.*]] = torch.aten.__is__ %[[VAL_11:.*]], %[[VAL_12:.*]] : !torch.list<int>, !torch.list<int> -> !torch.bool
// CHECK:             torch.prim.Print(%[[VAL_10]]) : !torch.bool
// CHECK:             torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<int>
// CHECK:           } : !torch.vtensor<[3,3],unk>
// CHECK:           %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor
// CHECK:           return %[[VAL_13]] : !torch.vtensor
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
  %true = torch.constant.bool true
  %int3 = torch.constant.int 3
  %int0 = torch.constant.int 0
  %0 = torch.shape.calculate {
    torch.shape.calculate.yield %arg0 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    %2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
    // TODO: Fix this miscompile!
    // For the case where %arg2 is true, the resulting IR will miscompile
    // because the abstract interpretation of the list ops will create two
    // list literals.
    // One possible solution would be to know that torch.prim.If.yield creates
    // a new SSA name for the same dynamic value (it's not the only thing that
    // can do this -- pushing and popping the list onto another list could
    // create the same situation). Another possible solution would be to only
    // replace a single list literal at a time, and bail out if there are any
    // uses of the original list value that are not replaced by the created
    // literal.
    %3 = torch.prim.If %arg2 -> (!torch.list<int>) {
      torch.prim.If.yield %1 : !torch.list<int>
    } else {
      torch.prim.If.yield %arg1 : !torch.list<int>
    }
    %4 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
    %5 = torch.prim.If %arg2 -> (!torch.list<int>) {
      torch.prim.If.yield %1 : !torch.list<int>
    } else {
      torch.prim.If.yield %arg1 : !torch.list<int>
    }
    %6 = torch.aten.__is__ %3, %5 : !torch.list<int>, !torch.list<int> -> !torch.bool
    torch.prim.Print(%6) : !torch.bool
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}

// "Integration test" for the basic case of all the patterns working together.
// This test should usually not be the one to catch an issue. If it does catch
// an issue, that indicates a missing, more precise unit test.

// CHECK-LABEL:   func @basic_integration(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
// CHECK:           %[[INT0:.*]] = torch.constant.int 0
// CHECK:           %[[INT1:.*]] = torch.constant.int 1
// CHECK:           %[[RESULT:.*]] = torch.shape.calculate {
// CHECK:             %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk>
// CHECK:             torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk>
// CHECK:           } shapes {
// CHECK:             %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK:             %[[SIZE1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK:             %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE0]], %[[SIZE1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:             torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK:           } : !torch.vtensor<[?,?],unk>
// CHECK:           %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK:           return %[[RESULT_ERASED]] : !torch.vtensor
func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
  %true = torch.constant.bool true
  %0 = torch.shape.calculate {
    %1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor
    torch.shape.calculate.yield %1 : !torch.vtensor
  } shapes {
    %1 = torch.prim.ListConstruct : () -> !torch.list<int>
    %2 = torch.aten.dim %arg0 : !torch.vtensor<[?,?],unk> -> !torch.int
    torch.prim.Loop %2, %true, init() {
    ^bb0(%arg1: !torch.int):
      %3 = torch.aten.size.int %arg0, %arg1 : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
      %4 = torch.aten.append.t %1, %3 : !torch.list<int>, !torch.int -> !torch.list<int>
      torch.prim.Loop.condition %true, iter()
    } : (!torch.int, !torch.bool) -> ()
    torch.shape.calculate.yield.shapes %1 : !torch.list<int>
  } : !torch.vtensor
  return %0 : !torch.vtensor
}