mirror of https://github.com/llvm/torch-mlir
importer: add initial support for loading Float16 tensors (#1169)
follow up #761: This patch updates the `torch_mlir::convertTensorToMlirElementsAttr()` method to enable the creation of tensors whose base type is Float16. This patch also adds a test to validate the IR generation, and it updates the test for importing tensors of various types.pull/1171/head snapshot-20220808.558
parent
1ee865983b
commit
290d7755fb
|
@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
|||
return torch_upstream::ScalarType::Bool;
|
||||
if (type.isBF16())
|
||||
return torch_upstream::ScalarType::BFloat16;
|
||||
if (type.isF16())
|
||||
return torch_upstream::ScalarType::Half;
|
||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||
}
|
||||
|
||||
|
@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType(
|
|||
return IntegerType::get(context, 1);
|
||||
case torch_upstream::ScalarType::BFloat16:
|
||||
return mlir::FloatType::getBF16(context);
|
||||
case torch_upstream::ScalarType::Half:
|
||||
return mlir::FloatType::getF16(context);
|
||||
default:
|
||||
return Type();
|
||||
}
|
||||
|
|
|
@ -353,6 +353,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
|
|||
case ScalarType::BFloat16:
|
||||
return mlirDenseElementsAttrBFloat16Get(
|
||||
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
|
||||
case ScalarType::Half:
|
||||
return mlirDenseElementsAttrFloat16Get(
|
||||
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
|
||||
|
||||
default:
|
||||
throwUnsupportedTensorError();
|
||||
}
|
||||
|
|
|
@ -220,3 +220,16 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
|
|||
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
|
||||
return %0 : !torch.vtensor<[?,?],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.neg.f16
|
||||
// CHECK: linalg.generic {{.*}} {
|
||||
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %{{.*}}: f16):
|
||||
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f16
|
||||
// CHECK-NEXT: linalg.yield %[[NEG]] : f16
|
||||
// CHECK-NEXT: } -> tensor<?x?xf16>
|
||||
func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
|
||||
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
|
||||
return %0 : !torch.vtensor<[?,?],f16>
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ class TestModule(torch.nn.Module):
|
|||
self.ones_f64 = torch.ones(1, dtype=torch.float64)
|
||||
self.ones_bool = torch.ones(1, dtype=torch.bool)
|
||||
self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
|
||||
self.ones_f16 = torch.ones(1, dtype=torch.half)
|
||||
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
|
||||
self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
|
||||
self.arange = torch.nn.Parameter(torch.arange(3.0))
|
||||
|
@ -34,6 +35,7 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
|
||||
# CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
|
||||
# CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
|
||||
# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
|
||||
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
|
||||
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
|
||||
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
|
||||
|
@ -49,6 +51,7 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
|
||||
# CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
|
||||
# CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
|
||||
# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
|
||||
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
|
||||
# CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
|
||||
# CHECK: }
|
||||
|
|
Loading…
Reference in New Issue