diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 96e9e78a2..cbe1d5b48 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op, return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } - // Dtype is always float32, except for bfloat16, float64 and nullptr. + // Dtype is always float32, except for bfloat16, float16, float64 and nullptr. if (isa(op)) { @@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op, Type dtype = operands[0]->getValue().dtype; if (dtype) { knowledge.dtype = Float32Type::get(op->getContext()); - if (dtype.isa()) + if (dtype.isa()) knowledge.dtype = dtype; } incorporateKnowledge(op->getResult(0), knowledge);