diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 95251c8d0..ad668ecad 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -893,6 +893,13 @@ void TypeAnalysis::visitOperation(Operation *op, if (auto sum = dyn_cast(op)) { Type defaultDtype = operands[0]->getValue().dtype; + if (!defaultDtype) { + incorporateKnowledge( + sum.getResult(), + ValueKnowledge::getTensorPessimisticValueState(op->getContext())); + return; + } + // If the input dtype is bool, the result type should be i64. if (defaultDtype.isInteger(1)) defaultDtype =