torch-mlir/test/Dialect/Torch/simplify-shape-calculations...

371 lines
20 KiB
MLIR
Raw Normal View History

// RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @refine_shape_calculate_result$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor to !torch.vtensor<[2,?],unk>
// CHECK: torch.shape.calculate.yield %[[REFINED]] : !torch.vtensor<[2,?],unk>
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[2,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_one_element(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,2],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If %arg2 -> () {
// Clobber element 0 of the list. So we can only know that the result is [?,2] instead of [2,2].
%2 = torch.aten._set_item.t %1, %int0, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_all_elements(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If %arg2 -> () {
// Set an unknown element of the list. This clobbers our knowledge of the whole contents of the list.
// So we can only know that the result is [?,?] instead of [2,2].
%2 = torch.aten._set_item.t %1, %arg1, %int0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// Make sure that information previously in the IR is not lost.
// CHECK-LABEL: func @refine_shape_calculate_result$meet_with_existing_information(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,3],f32>
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor<[?,3],f32>
func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor<[?,3],f32>
} shapes {
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor<[?,3],f32>
return %0 : !torch.vtensor<[?,3],f32>
}
// Don't insert static info casts if not needed.
// CHECK-LABEL: func @refine_shape_calculate_result$user_allows_type_refinement(
// CHECK-NOT: torch.tensor_static_info_cast
func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor {
%int2 = torch.constant.int 2
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
%1 = torch.shape.calculate {
torch.shape.calculate.yield %0 : !torch.vtensor
} shapes {
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %2 : !torch.list<int>
} : !torch.vtensor
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: torch.prim.Print(%[[INT0]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.prim.Print(%[[INT1]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.prim.Print(%[[INT2]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
torch.prim.Loop %int3, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
torch.prim.Print(%arg2, %arg3) : !torch.int, !torch.int
torch.prim.Loop.condition %true, iter(%arg3: !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> (!torch.int)
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @fully_unroll_prim_loop$no_unroll(
// CHECK: torch.prim.Loop
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
torch.prim.Loop %arg2, %true, init() {
^bb0(%arg3: !torch.int):
torch.prim.Print(%arg2) : !torch.int
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
// CHECK-SAME: %[[ARG2:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// Test the different supported mutation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten._set_item.t %1, %int1, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.aten.insert.t %1, %int3, %arg3 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(
// CHECK: torch.aten.append.t
// CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
// The value of the alias %2 is printed, but we don't handle that yet.
torch.prim.Print(%2) : !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$readonly_op_in_child_region(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
// This readonly op in a loop doesn't block us from abstractly interpreting
// the whole block.
torch.prim.Loop %arg1, %true, init() {
^bb0(%arg3: !torch.int):
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Print(%2) : !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// The mutation in the child region prevents us from abstractly interpreting.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_in_child_region(
// CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %arg1, %true, init() {
^bb0(%arg3: !torch.int):
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Print(%2) : !torch.list<int>
// This mutation prevents us from abstractly interpreting.
%3 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
// CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor to !torch.vtensor<[3,3],unk>
// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[3,3],unk>
// CHECK: } shapes {
// Notice this torch.prim.ListConstruct....
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK: torch.prim.If.yield %[[VAL_6]] : !torch.list<int>
// CHECK: } else {
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK: }
// .... and this one don't have the same object identity, but should!
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK: torch.prim.If.yield %[[VAL_8]] : !torch.list<int>
// CHECK: } else {
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK: }
// CHECK: %[[VAL_10:.*]] = torch.aten.__is__ %[[VAL_11:.*]], %[[VAL_12:.*]] : !torch.list<int>, !torch.list<int> -> !torch.bool
// CHECK: torch.prim.Print(%[[VAL_10]]) : !torch.bool
// CHECK: torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[3,3],unk>
// CHECK: %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor
// CHECK: return %[[VAL_13]] : !torch.vtensor
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
// TODO: Fix this miscompile!
// For the case where %arg2 is true, the resulting IR will miscompile
// because the abstract interpretation of the list ops will create two list
// literals.
// One possible solution would be to know that torch.prim.If.yield creates
// a new SSA name for the same dynamic value (it's not the only thing that
// can do this -- pushing and popping the list onto another list could
// create the same situation). Another possible solution would be to only
// replace a single list literal at a time, and bail out if there are any
// uses of the original list value that are not replaced by the created
// literal.
%3 = torch.prim.If %arg2 -> (!torch.list<int>) {
torch.prim.If.yield %1 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<int>
}
%4 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
%5 = torch.prim.If %arg2 -> (!torch.list<int>) {
torch.prim.If.yield %1 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<int>
}
%6 = torch.aten.__is__ %3, %5 : !torch.list<int>, !torch.list<int> -> !torch.bool
torch.prim.Print(%6) : !torch.bool
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// "Integration test" for basic case of all the patterns working together.
// This test should usually not be the one to catch an issue.
// If it does catch an issue then it indicates a more precise unit test that is
// missing.
// CHECK-LABEL: func @basic_integration(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk>
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk>
// CHECK: } shapes {
// CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK: %[[SIZE1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE0]], %[[SIZE1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[?,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
%true = torch.constant.bool true
%0 = torch.shape.calculate {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor
torch.shape.calculate.yield %1 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.dim %arg0 : !torch.vtensor<[?,?],unk> -> !torch.int
torch.prim.Loop %2, %true, init() {
^bb0(%arg1: !torch.int):
%3 = torch.aten.size.int %arg0, %arg1 : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
%4 = torch.aten.append.t %1, %3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}