TorchToTosa: Support casts from and to bf16 (#2118)

pull/2027/head
Matthias Gehre 2023-05-13 00:18:23 +02:00 committed by GitHub
parent 17db2aafa3
commit 3a8196588f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 2 deletions

View File

@ -222,21 +222,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
} }
static LogicalResult checkValidityOfCast(Type src, Type dest) { static LogicalResult checkValidityOfCast(Type src, Type dest) {
if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || if ((src == dest) ||
(src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(64) && dest.isInteger(8)) || (src.isInteger(64) && dest.isInteger(8)) ||
(src.isInteger(64) && dest.isInteger(1)) || (src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(64) && dest.isF32()) || (src.isInteger(64) && dest.isF32()) ||
(src.isInteger(32) && dest.isInteger(64)) || (src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) || (src.isInteger(32) && dest.isF32()) ||
(src.isInteger(32) && dest.isBF16()) ||
(src.isInteger(16) && dest.isBF16()) ||
(src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
(src.isInteger(1) && dest.isInteger(64)) || (src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) || (src.isInteger(1) && dest.isF32()) ||
(src.isF32() && dest.isF64()) || (src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isBF16()) ||
(src.isF64() && dest.isF32()) || (src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) ||
(src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(64)) || (src.isF32() && dest.isInteger(64)) ||
(src.isF32() && dest.isInteger(1))) { (src.isF32() && dest.isInteger(1)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isF32())) {
return success(); return success();
} }
return failure(); return failure();