diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 09be73436..fa56ca39d 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -222,21 +222,31 @@ std::optional getConstTensor(PatternRewriter &rewriter, } static LogicalResult checkValidityOfCast(Type src, Type dest) { - if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || + if ((src == dest) || + (src.isInteger(64) && dest.isInteger(32)) || (src.isInteger(64) && dest.isInteger(8)) || (src.isInteger(64) && dest.isInteger(1)) || (src.isInteger(64) && dest.isF32()) || (src.isInteger(32) && dest.isInteger(64)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || + (src.isInteger(32) && dest.isBF16()) || + (src.isInteger(16) && dest.isBF16()) || (src.isInteger(8) && dest.isInteger(1)) || + (src.isInteger(8) && dest.isBF16()) || (src.isInteger(1) && dest.isInteger(64)) || (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) || + (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(64)) || - (src.isF32() && dest.isInteger(1))) { + (src.isF32() && dest.isInteger(1)) || + (src.isBF16() && dest.isInteger(8)) || + (src.isBF16() && dest.isInteger(16)) || + (src.isBF16() && dest.isInteger(32)) || + (src.isBF16() && dest.isF32())) { return success(); } return failure();