mirror of https://github.com/llvm/torch-mlir
Handle uninitialized lattice elements in RefineTypes (#1911)
The data-flow analysis does not always propagate information to the entire graph. This results in some lattice elements being uninitialized. Currently the lattice elements are not checked to see if they are uninitialized before rewriting the graph, potentially resulting in invalid IR (see https://github.com/llvm/torch-mlir/issues/1896). This commit adds handling for uninitialized lattice elements.pull/1914/head snapshot-20230304.767
parent
a4602c674c
commit
d30af8772b
|
@ -1532,12 +1532,16 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
|
|||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
return getRefinedTensorType(tensorType, knowledge);
|
||||
} else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
|
||||
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
|
||||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
if (knowledge.optional == OptionalKnowledge::isNone)
|
||||
return Torch::NoneType::get(v.getContext());
|
||||
else if (knowledge.optional == OptionalKnowledge::notNone) {
|
||||
|
@ -1552,6 +1556,8 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
|
|||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
if (knowledge.kind == torch_upstream::TypeKind::IntType)
|
||||
return Torch::IntType::get(v.getContext());
|
||||
if (knowledge.kind == torch_upstream::TypeKind::FloatType)
|
||||
|
|
|
@ -212,3 +212,27 @@ func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
|
|||
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The data-flow analysis does not always propagate information to the entire graph.
|
||||
// This results in some lattice elements being uninitialized, which must be properly
|
||||
// handled when using the lattice elements to rewrite the graph.
|
||||
// In this particular case, the presence of the loop causes `torch.copy.to_vtensor`
|
||||
// to end up with an uninitialized lattice element. This is the simplest graph I was
|
||||
// able to come up with that reproduces such behavior.
|
||||
|
||||
// CHECK-LABEL: func.func @uninitialized_lattice_elements(
|
||||
// CHECK: %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32>
|
||||
|
||||
func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> {
|
||||
%true = torch.constant.bool true
|
||||
%1 = torch.constant.int 0
|
||||
%2 = torch.prim.Loop %1, %true, init(%arg3) {
|
||||
^bb0(%arg1: !torch.int, %arg2: !torch.tensor):
|
||||
torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor)
|
||||
} : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor
|
||||
%3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32>
|
||||
%4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
|
||||
return %4 : !torch.vtensor<*,f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue