diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 7e9f9a947..906243e14 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::Bool;
   if (type.isBF16())
     return torch_upstream::ScalarType::BFloat16;
+  if (type.isF16())
+    return torch_upstream::ScalarType::Half;
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 
@@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType(
     return IntegerType::get(context, 1);
   case torch_upstream::ScalarType::BFloat16:
     return mlir::FloatType::getBF16(context);
+  case torch_upstream::ScalarType::Half:
+    return mlir::FloatType::getF16(context);
   default:
     return Type();
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index 32c2e157c..471e2d570 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -353,6 +353,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::BFloat16:
     return mlirDenseElementsAttrBFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
+  case ScalarType::Half:
+    return mlirDenseElementsAttrFloat16Get(
+        shapedType, numElements, static_cast<const uint16_t *>(tensorData));
+
   default:
     throwUnsupportedTensorError();
   }
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 1e0f66582..dd07e3848 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -220,3 +220,16 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
   return %0 : !torch.vtensor<[?,?],bf16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.neg.f16
+// CHECK: linalg.generic {{.*}} {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %{{.*}}: f16):
+// CHECK-NEXT:   %[[NEG:.*]] = arith.negf %[[LHS]] : f16
+// CHECK-NEXT:   linalg.yield %[[NEG]] : f16
+// CHECK-NEXT: } -> tensor<?x?xf16>
+func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
+  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
+  return %0 : !torch.vtensor<[?,?],f16>
+}
diff --git a/test/python/importer/jit_ir/ivalue_import/tensors.py b/test/python/importer/jit_ir/ivalue_import/tensors.py
index 02a82f06e..314c245f5 100644
--- a/test/python/importer/jit_ir/ivalue_import/tensors.py
+++ b/test/python/importer/jit_ir/ivalue_import/tensors.py
@@ -22,6 +22,7 @@ class TestModule(torch.nn.Module):
         self.ones_f64 = torch.ones(1, dtype=torch.float64)
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
+        self.ones_f16 = torch.ones(1, dtype=torch.half)
         self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
         self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
@@ -34,6 +35,7 @@ class TestModule(torch.nn.Module):
 # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
 # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
 # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
+# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
 # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
 # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
 # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
@@ -49,6 +51,7 @@ class TestModule(torch.nn.Module):
 # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
 # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
 # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
+# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
 # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
 # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
 # CHECK: }
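
For context, and not part of the patch: a minimal standalone sketch of exercising the new half-precision import path, assuming the ModuleBuilder API already used by the jit_ir importer tests such as tensors.py above.

# Sketch only: mirrors what the updated tensors.py lit test checks, but as a
# plain script. Assumes a torch-mlir build with the jit_ir importer available.
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder


class HalfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torch.half is an alias for torch.float16 (ScalarType::Half).
        self.ones_f16 = torch.ones(1, dtype=torch.half)

    def forward(self):
        return self.ones_f16


mb = ModuleBuilder()
mb.import_module(torch.jit.script(HalfModule())._c)
# With the changes above, the slot should import as
#   torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
mb.module.operation.print()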