From 00fc14a6e1bdafa5b68d71f6c1f762e6460e1910 Mon Sep 17 00:00:00 2001 From: Chi_Liu Date: Fri, 27 Jan 2023 09:21:06 -0800 Subject: [PATCH] [TOSA] Add to.dtype i1 to i64 (#1830) --- .../TorchToTosa/TosaLegalizeUtils.cpp | 1 + test/Conversion/TorchToTosa/basic.mlir | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 0ab8f0a4e..6ecb3f0ca 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -230,6 +230,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) { (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || (src.isInteger(8) && dest.isInteger(1)) || + (src.isInteger(1) && dest.isInteger(64)) || (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(1))) { return success(); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index dc037353e..73b2cbeab 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -939,6 +939,25 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten return %0 : !torch.vtensor<[3,5],i1> } +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,128],i1> -> tensor<1x128xi1> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x128xi1>) -> tensor<1x128xi64> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,128],si64> +// CHECK: } +func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { + %int4 = torch.constant.int 4 + %none = torch.constant.none + %false = torch.constant.bool false + %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[1,128],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128],si64> + return %0 : !torch.vtensor<[1,128],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.gather( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,