mirror of https://github.com/llvm/torch-mlir
TorchToTosa: Support casts from and to bf16 (#2118)
parent
17db2aafa3
commit
3a8196588f
|
@ -222,21 +222,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
static LogicalResult checkValidityOfCast(Type src, Type dest) {
|
||||
if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) ||
|
||||
if ((src == dest) ||
|
||||
(src.isInteger(64) && dest.isInteger(32)) ||
|
||||
(src.isInteger(64) && dest.isInteger(8)) ||
|
||||
(src.isInteger(64) && dest.isInteger(1)) ||
|
||||
(src.isInteger(64) && dest.isF32()) ||
|
||||
(src.isInteger(32) && dest.isInteger(64)) ||
|
||||
(src.isInteger(32) && dest.isInteger(1)) ||
|
||||
(src.isInteger(32) && dest.isF32()) ||
|
||||
(src.isInteger(32) && dest.isBF16()) ||
|
||||
(src.isInteger(16) && dest.isBF16()) ||
|
||||
(src.isInteger(8) && dest.isInteger(1)) ||
|
||||
(src.isInteger(8) && dest.isBF16()) ||
|
||||
(src.isInteger(1) && dest.isInteger(64)) ||
|
||||
(src.isInteger(1) && dest.isF32()) ||
|
||||
(src.isF32() && dest.isF64()) ||
|
||||
(src.isF32() && dest.isBF16()) ||
|
||||
(src.isF64() && dest.isF32()) ||
|
||||
(src.isF64() && dest.isBF16()) ||
|
||||
(src.isF32() && dest.isInteger(8)) ||
|
||||
(src.isF32() && dest.isInteger(64)) ||
|
||||
(src.isF32() && dest.isInteger(1))) {
|
||||
(src.isF32() && dest.isInteger(1)) ||
|
||||
(src.isBF16() && dest.isInteger(8)) ||
|
||||
(src.isBF16() && dest.isInteger(16)) ||
|
||||
(src.isBF16() && dest.isInteger(32)) ||
|
||||
(src.isBF16() && dest.isF32())) {
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
|
Loading…
Reference in New Issue