allow tosa.cast to convert from f32 to f16 (#2934)

According to the [official TOSA
spec](https://www.mlplatform.org/tosa/tosa_spec.html#_cast), `tosa.cast`
allows a cast from `fp32` to `fp16`. We were not previously accounting
for this in the `TorchToTosa` lowering.

Also did a tiny bit of cleanup in the code to make it easier to spot
which conversions are currently allowed.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
pull/2908/head
Srinath Avadhanula 2024-02-20 17:22:38 -05:00 committed by GitHub
parent 534b266f2d
commit 0f80e75c2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 5 deletions

View File

@ -266,28 +266,44 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
}
static LogicalResult checkValidityOfCast(Type src, Type dest) {
if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) ||
// clang-format off
if ((src == dest) ||
// int64 -> *
(src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(64) && dest.isInteger(8)) ||
(src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(64) && dest.isF32()) ||
// int32 -> *
(src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
(src.isInteger(32) && dest.isBF16()) ||
// int16 -> *
(src.isInteger(16) && dest.isBF16()) ||
// int8 -> *
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
// int1 -> *
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) ||
(src.isInteger(1) && dest.isF32()) ||
// f64 -> *
(src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) ||
// f32 -> *
(src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(64)) ||
(src.isF32() && dest.isInteger(1)) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) {
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isF32())) {
return success();
}
// clang-format on
return failure();
}

View File

@ -0,0 +1,12 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file
// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16>
func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> {
%int5 = torch.constant.int 5
%false = torch.constant.bool false
%none = torch.constant.none
%out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16>
return %out : !torch.vtensor<[1,32,220,220],f16>
}