importer: add initial support for loading Float16 tensors (#1169)

Follow-up to #761:

    This patch updates the `torch_mlir::convertTensorToMlirElementsAttr()`
    function to enable the creation of tensors whose base type is Float16.
    It also adds a test to validate the generated IR and updates the
    existing test that imports tensors of various types.
Tanyo Kwok authored on 2022-08-08 12:37:31 +08:00, committed by GitHub
commit 290d7755fb (parent 1ee865983b)
4 changed files with 24 additions and 0 deletions
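
In short, the new code path hands the tensor's raw 16-bit words to the MLIR C API. A minimal standalone sketch of that call, using only the MLIR C API (the helper name, shape, and context handling here are illustrative, not part of the patch):

    #include "mlir-c/BuiltinAttributes.h"
    #include "mlir-c/BuiltinTypes.h"
    #include "mlir-c/IR.h"

    #include <stdint.h>

    // Illustrative helper: build dense<1.0> : tensor<1xf16> the same way the
    // importer now does for ScalarType::Half.
    static MlirAttribute makeF16OnesAttr(MlirContext ctx) {
      int64_t shape[] = {1};
      MlirType f16 = mlirF16TypeGet(ctx);
      MlirType shapedType =
          mlirRankedTensorTypeGet(/*rank=*/1, shape, f16, mlirAttributeGetNull());
      // 0x3C00 is the IEEE fp16 bit pattern for 1.0; the C API takes the raw
      // 16-bit words as uint16_t, matching the cast in the hunks below.
      uint16_t data[] = {0x3C00};
      return mlirDenseElementsAttrFloat16Get(shapedType, /*numElements=*/1, data);
    }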


@@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::Bool;
   if (type.isBF16())
     return torch_upstream::ScalarType::BFloat16;
+  if (type.isF16())
+    return torch_upstream::ScalarType::Half;
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
@@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType(
     return IntegerType::get(context, 1);
   case torch_upstream::ScalarType::BFloat16:
     return mlir::FloatType::getBF16(context);
+  case torch_upstream::ScalarType::Half:
+    return mlir::FloatType::getF16(context);
   default:
     return Type();
   }
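
With both helpers extended, f16 now round-trips between the MLIR type and the torch scalar-type enum. A hypothetical sanity check; the header path and the (MLIRContext *, ScalarType) parameter order are assumptions inferred from the hunks above:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "torch-mlir/Dialect/Torch/Utils/Utils.h" // header path assumed
    #include <cassert>

    using namespace mlir;
    using namespace mlir::torch;

    // Hypothetical check: F16 <-> ScalarType::Half now maps in both directions.
    void checkF16RoundTrip() {
      MLIRContext context;
      Type f16 = FloatType::getF16(&context);
      torch_upstream::ScalarType half = Torch::getScalarTypeForType(f16);
      assert(half == torch_upstream::ScalarType::Half);
      // Parameter order (context, scalar type) assumed from the hunk above.
      assert(Torch::getTypeForScalarType(&context, half) == f16);
    }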


@@ -353,6 +353,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::BFloat16:
     return mlirDenseElementsAttrBFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
+  case ScalarType::Half:
+    return mlirDenseElementsAttrFloat16Get(
+        shapedType, numElements, static_cast<const uint16_t *>(tensorData));
   default:
     throwUnsupportedTensorError();
   }
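
The `static_cast<const uint16_t *>` works because at::Half, like at::BFloat16, is stored as 16-bit words, so the buffer can be passed to the C API as raw bit patterns. A small illustration of that layout using only libtorch (the program itself is illustrative):

    #include <ATen/ATen.h>

    #include <cstdint>
    #include <cstdio>

    int main() {
      // A Half tensor's storage is a sequence of IEEE fp16 16-bit words.
      at::Tensor t = at::ones({1}, at::kHalf);
      const uint16_t *bits = static_cast<const uint16_t *>(t.data_ptr());
      // Prints 0x3C00, the fp16 encoding of 1.0.
      std::printf("0x%04X\n", static_cast<unsigned>(bits[0]));
      return 0;
    }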


@@ -220,3 +220,16 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vtensor<[?,?],bf16> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
   return %0 : !torch.vtensor<[?,?],bf16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.neg.f16
+// CHECK: linalg.generic {{.*}} {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %{{.*}}: f16):
+// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f16
+// CHECK-NEXT: linalg.yield %[[NEG]] : f16
+// CHECK-NEXT: } -> tensor<?x?xf16>
+func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
+  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
+  return %0 : !torch.vtensor<[?,?],f16>
+}


@@ -22,6 +22,7 @@ class TestModule(torch.nn.Module):
         self.ones_f64 = torch.ones(1, dtype=torch.float64)
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
+        self.ones_f16 = torch.ones(1, dtype=torch.half)
         self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
         self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
@@ -34,6 +35,7 @@ class TestModule(torch.nn.Module):
 # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
 # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
 # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
+# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
 # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
 # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
 # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
@@ -49,6 +51,7 @@ class TestModule(torch.nn.Module):
 # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
 # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
 # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
+# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
 # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
 # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
 # CHECK: }