Handle uninitialized lattice elements in RefineTypes (#1911)

The data-flow analysis does not always propagate information to the
entire graph, which leaves some lattice elements uninitialized.
Currently, the lattice elements are not checked for initialization
before being used to rewrite the graph, potentially resulting in
invalid IR (see https://github.com/llvm/torch-mlir/issues/1896).

This commit adds handling for uninitialized lattice elements.
Ramiro Leal-Cavazos 2023-03-03 08:55:58 -08:00 committed by GitHub
parent a4602c674c
commit d30af8772b
2 changed files with 30 additions and 0 deletions
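
The fix, shown in the first diff below, guards each lattice lookup in
getMostRefinedStaticType with an initialization check before the recovered
knowledge is consulted. A condensed sketch of the pattern, with the
surrounding per-type dispatch elided (all names are taken from the diff
itself):

    const ValueState *latticeElement = solver.lookupState<ValueState>(v);
    // No state at all: the solver never created a lattice element for `v`.
    if (!latticeElement)
      return nullptr;
    const ValueKnowledge &knowledge = latticeElement->getValue();
    // A state exists but was never joined with real information, e.g. because
    // the analysis did not reach this part of the graph; rewriting based on it
    // could produce invalid IR.
    if (!knowledge.isInitialized)
      return nullptr;
    // Only past this point is `knowledge` safe to use for type refinement.

Returning a null Type signals that no refined static type is available, so
the rewrite is conservatively skipped for that value.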

lib/Dialect/Torch/Transforms/RefineTypes.cpp

@@ -1532,12 +1532,16 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
     if (!latticeElement)
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (!knowledge.isInitialized)
+      return nullptr;
     return getRefinedTensorType(tensorType, knowledge);
   } else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
     const ValueState *latticeElement = solver.lookupState<ValueState>(v);
     if (!latticeElement)
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (!knowledge.isInitialized)
+      return nullptr;
     if (knowledge.optional == OptionalKnowledge::isNone)
       return Torch::NoneType::get(v.getContext());
     else if (knowledge.optional == OptionalKnowledge::notNone) {
@@ -1552,6 +1556,8 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
     if (!latticeElement)
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
+    if (!knowledge.isInitialized)
+      return nullptr;
     if (knowledge.kind == torch_upstream::TypeKind::IntType)
       return Torch::IntType::get(v.getContext());
     if (knowledge.kind == torch_upstream::TypeKind::FloatType)

test/Dialect/Torch/refine-types.mlir

@@ -212,3 +212,27 @@ func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
   %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
   return
 }
+
+// -----
+
+// The data-flow analysis does not always propagate information to the entire graph.
+// This results in some lattice elements being uninitialized, which must be properly
+// handled when using the lattice elements to rewrite the graph.
+//
+// In this particular case, the presence of the loop causes `torch.copy.to_vtensor`
+// to end up with an uninitialized lattice element. This is the simplest graph I was
+// able to come up with that reproduces such behavior.
+
+// CHECK-LABEL:   func.func @uninitialized_lattice_elements(
+// CHECK:           %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32>
+func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> {
+  %true = torch.constant.bool true
+  %1 = torch.constant.int 0
+  %2 = torch.prim.Loop %1, %true, init(%arg3) {
+  ^bb0(%arg1: !torch.int, %arg2: !torch.tensor):
+    torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor)
+  } : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor
+  %3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32>
+  %4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
+  return %4 : !torch.vtensor<*,f32>
+}
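
To reproduce the failure this test guards against, it should be enough to run
the new case through the pass with the usual FileCheck harness. Assuming the
RUN line used by the other refine-types tests (torch-mlir-opt driving the
torch-refine-types pass in split-input-file mode; the exact invocation lives
at the top of the test file), something like:

    torch-mlir-opt -torch-refine-types -split-input-file refine-types.mlir | FileCheck refine-types.mlir

Before this commit, the uninitialized lattice element attached to
torch.copy.to_vtensor would be used for rewriting, producing the invalid IR
reported in https://github.com/llvm/torch-mlir/issues/1896.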