mirror of https://github.com/llvm/torch-mlir
allow tosa.cast to convert from f32 to f16 (#2934)
According to the [official TOSA spec](https://www.mlplatform.org/tosa/tosa_spec.html#_cast), `tosa.cast` allows a cast from `fp32` to `fp16`. We were not previously accounting for this in the `TorchToTosa` lowering. This patch also does a tiny bit of cleanup in the code to make it easier to spot which conversions are currently allowed.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
parent 534b266f2d
commit 0f80e75c2e
@@ -266,28 +266,44 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
 }
 
 static LogicalResult checkValidityOfCast(Type src, Type dest) {
-  if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) ||
+  // clang-format off
+  if ((src == dest) ||
+      // int64 -> *
+      (src.isInteger(64) && dest.isInteger(32)) ||
       (src.isInteger(64) && dest.isInteger(8)) ||
       (src.isInteger(64) && dest.isInteger(1)) ||
       (src.isInteger(64) && dest.isF32()) ||
+      // int32 -> *
       (src.isInteger(32) && dest.isInteger(64)) ||
       (src.isInteger(32) && dest.isInteger(1)) ||
       (src.isInteger(32) && dest.isF32()) ||
       (src.isInteger(32) && dest.isBF16()) ||
+      // int16 -> *
       (src.isInteger(16) && dest.isBF16()) ||
+      // int8 -> *
       (src.isInteger(8) && dest.isInteger(1)) ||
       (src.isInteger(8) && dest.isBF16()) ||
+      // int1 -> *
       (src.isInteger(1) && dest.isInteger(64)) ||
-      (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) ||
-      (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) ||
-      (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) ||
+      (src.isInteger(1) && dest.isF32()) ||
+      // f64 -> *
+      (src.isF64() && dest.isF32()) ||
+      (src.isF64() && dest.isBF16()) ||
+      // f32 -> *
+      (src.isF32() && dest.isF64()) ||
+      (src.isF32() && dest.isBF16()) ||
+      (src.isF32() && dest.isF16()) ||
+      (src.isF32() && dest.isInteger(8)) ||
       (src.isF32() && dest.isInteger(64)) ||
       (src.isF32() && dest.isInteger(1)) ||
+      // bf16 -> *
       (src.isBF16() && dest.isInteger(8)) ||
       (src.isBF16() && dest.isInteger(16)) ||
-      (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) {
+      (src.isBF16() && dest.isInteger(32)) ||
+      (src.isBF16() && dest.isF32())) {
     return success();
   }
+  // clang-format on
   return failure();
 }
 
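For context on how a check like this is typically consumed, here is a minimal sketch of a conversion-pattern helper that guards `tosa.cast` emission with `checkValidityOfCast`. The helper name `castToTosaIfValid` and its exact signature are illustrative assumptions, not the actual torch-mlir call site; `tosa::CastOp` and `PatternRewriter` are the standard MLIR APIs, and the `std::optional<Value>` return mirrors the `getConstTensor` convention visible in the hunk header above.

```cpp
// Illustrative sketch only; assumes the usual TorchToTosa includes
// (mlir/Dialect/Tosa/IR/TosaOps.h, mlir/IR/PatternMatch.h, <optional>).
static std::optional<Value> castToTosaIfValid(PatternRewriter &rewriter,
                                              Operation *op, Value src,
                                              TensorType destType) {
  Type srcElemTy = cast<TensorType>(src.getType()).getElementType();
  Type destElemTy = destType.getElementType();
  // Reject element-type pairs that are not whitelisted in checkValidityOfCast;
  // the caller can then report a legalization failure for the op being lowered.
  if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
    return std::nullopt;
  // With the f32 -> f16 entry added above, an fp16 destination now passes
  // the check and a plain tosa.cast is emitted.
  return rewriter.create<tosa::CastOp>(op->getLoc(), destType, src)
      .getResult();
}
```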
@@ -0,0 +1,12 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file
+
+// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16>
+func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> {
+  %int5 = torch.constant.int 5
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16>
+  return %out : !torch.vtensor<[1,32,220,220],f16>
+}
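Note on the test: `%int5` is PyTorch's ScalarType code for `float16` (`torch.half`), so the `torch.aten.to.dtype` call requests an f16 result and exercises the newly allowed `f32 -> f16` cast.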