mirror of https://github.com/llvm/torch-mlir
[RefineTypes] Add Float16Type dtype knowledge support for trivial ops

This commit adds Float16Type dtype knowledge support for trivial ops.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
parent 0983a7f93a
commit 47f67853ac
@@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op,
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
-  // Dtype is always float32, except for bfloat16, float64 and nullptr.
+  // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
   if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
           AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
           AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
@@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op,
     Type dtype = operands[0]->getValue().dtype;
     if (dtype) {
       knowledge.dtype = Float32Type::get(op->getContext());
-      if (dtype.isa<BFloat16Type, Float64Type>())
+      if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
         knowledge.dtype = dtype;
     }
     incorporateKnowledge(op->getResult(0), knowledge);
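Below is a minimal sketch (not part of the patch) that isolates the refinement rule this commit changes: the result dtype of these trivial elementwise ops defaults to f32, but bf16, f16, and f64 inputs now pass through unchanged, and an unknown (null) dtype stays unknown. The helper name refineTrivialOpDtype is hypothetical; the sketch assumes linking against MLIR's IR library and uses the same Type::isa style as the patched code.

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Mirrors the post-patch dtype rule from TypeAnalysis::visitOperation:
// bf16/f16/f64 are preserved, a null dtype stays null, everything else
// refines to f32. (Hypothetical standalone helper for illustration.)
static Type refineTrivialOpDtype(Type dtype, MLIRContext *ctx) {
  if (!dtype)
    return Type(); // nullptr dtype: no knowledge to record
  if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
    return dtype; // preserved as-is; f16 is the case this commit adds
  return Float32Type::get(ctx);
}

int main() {
  MLIRContext ctx;
  // f16 now survives refinement instead of being widened to f32.
  Type f16 = Float16Type::get(&ctx);
  assert(refineTrivialOpDtype(f16, &ctx) == f16);
  // Other dtypes (e.g. i32) still refine to f32.
  Type i32 = IntegerType::get(&ctx, 32);
  assert(refineTrivialOpDtype(i32, &ctx).isa<Float32Type>());
  return 0;
}

Adding Float16Type to the variadic isa<> list is the entire behavioral change: before this commit, f16 inputs to ops such as aten.tanh or aten.exp were refined to f32 results.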