mirror of https://github.com/llvm/torch-mlir
[RefineTypes] Add Float16Type dtype knowledge support for trivial ops
-- This commit adds Float16Type dtype knowledge support for trivial ops. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>tanyo/fix_upstream
parent
0983a7f93a
commit
47f67853ac
|
@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for bfloat16, float64 and nullptr.
|
||||
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
||||
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
|
||||
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
|
||||
|
@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
Type dtype = operands[0]->getValue().dtype;
|
||||
if (dtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
if (dtype.isa<BFloat16Type, Float64Type>())
|
||||
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
|
||||
knowledge.dtype = dtype;
|
||||
}
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
|
|
Loading…
Reference in New Issue