Add support for ScalarType::QUInt8

I ran into this while poking around at
https://github.com/llvm/torch-mlir/issues/959
pull/996/head
Sean Silva 2022-06-25 00:20:26 +00:00
parent cd79538a0c
commit 227dea7b2e
8 changed files with 54 additions and 1 deletion

View File

@@ -201,6 +201,7 @@ $TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonic
 ```
 Most of the unit tests use the [`FileCheck` tool](https://llvm.org/docs/CommandGuide/FileCheck.html) to verify expected outputs.
 # Updating the LLVM submodule
 Torch-MLIR maintains `llvm-project` (which contains, among other things,

View File

@@ -145,6 +145,16 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
 /// Gets the !torch.qint8 type.
 MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.quint8 type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
+
+/// Gets the !torch.quint8 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
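The new entry points mirror the existing `qint8` pair. A minimal usage sketch, assuming the torch dialect has already been registered and loaded on the context (the registration step is elided here and is not part of this patch):

```cpp
#include "mlir-c/IR.h"
#include "torch-mlir-c/TorchTypes.h"

int main() {
  MlirContext ctx = mlirContextCreate();
  // Assumes the torch dialect is registered/loaded on `ctx` first.
  MlirType quint8 = torchMlirTorchQUInt8TypeGet(ctx);
  bool isQUInt8 = torchMlirTypeIsATorchQUInt8(quint8); // true
  bool isQInt8 = torchMlirTypeIsATorchQInt8(quint8);   // false: distinct type
  (void)isQUInt8;
  (void)isQInt8;
  mlirContextDestroy(ctx);
  return 0;
}
```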

View File

@@ -130,6 +130,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
 |   torch.int64     | si64               |
 |   torch.bool      | i1                 |
 |   torch.qint8     | !torch.qint8       |
+|   torch.quint8    | !torch.quint8      |
 |-------------------|--------------------|
 ```
@@ -295,6 +296,17 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
+
+def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
+  let summary = "Type modeling `ScalarType::QUInt8`";
+  let description = [{
+    This is intended to be a 1:1 match for the Torch `ScalarType` types.
+    Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
+    types, it is deemed preferable to import them as one-off ad-hoc types
+    instead of a single parameterized type.
+  }];
+}
+
 def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
   let summary = "Torch packed linear params type";
   let description = [{
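As with the other Torch types, this TableGen definition generates a C++ type class in the `mlir::torch::Torch` namespace. A sketch of how downstream C++ code can construct it (the helper function is illustrative; the `get` call itself appears verbatim in the C API implementation below):

```cpp
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// Illustrative helper: build the singleton !torch.quint8 type.
mlir::Type getQUInt8Type(mlir::MLIRContext *ctx) {
  return mlir::torch::Torch::QUInt8Type::get(ctx);
}
```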

View File

@@ -174,6 +174,18 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
   return wrap(Torch::QInt8Type::get(unwrap(context)));
 }
+
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQUInt8(MlirType t) {
+  return unwrap(t).isa<Torch::QUInt8Type>();
+}
+
+MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
+  return wrap(Torch::QUInt8Type::get(unwrap(context)));
+}
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//

View File

@@ -147,7 +147,7 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
 static bool isValidTorchDtype(Type dtype) {
   // Torch quantized types.
-  if (dtype.isa<Torch::QInt8Type>())
+  if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type>())
     return true;
   // Builtin floating point types.
   if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
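Note that `mlir::Type::isa` with multiple template arguments is an "any of" check, so the updated line accepts both quantized dtypes in one call. The same pattern as a standalone sketch (the free function is illustrative, not part of this patch):

```cpp
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// Illustrative: true for either quantized dtype, false otherwise.
bool isTorchQuantizedDtype(mlir::Type dtype) {
  return dtype.isa<mlir::torch::Torch::QInt8Type,
                   mlir::torch::Torch::QUInt8Type>();
}
```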

View File

@@ -52,6 +52,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
     return mlirF16TypeGet(context);
   case ScalarType::QInt8:
     return torchMlirTorchQInt8TypeGet(context);
+  case ScalarType::QUInt8:
+    return torchMlirTorchQUInt8TypeGet(context);
   default: {
     return {nullptr};
   }
@@ -328,6 +330,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::QInt8:
     return mlirDenseElementsAttrInt8Get(
         shapedType, numElements, static_cast<const int8_t *>(tensorData));
+  case ScalarType::QUInt8:
+    return mlirDenseElementsAttrUInt8Get(
+        shapedType, numElements, static_cast<const uint8_t *>(tensorData));
   case ScalarType::BFloat16:
     return mlirDenseElementsAttrBFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
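The raw storage of a `quint8` tensor is plain `uint8` bytes, which is why the data pointer can be handed straight to the `ui8` dense-elements builder; the scale and zero point travel separately. A sketch of the standard ATen accessors involved (illustrative helper, not code from this patch):

```cpp
#include <ATen/ATen.h>

// Decompose a per-tensor-affine quantized tensor into the three pieces
// the importer reassembles with torch.per_tensor_affine.create.
void decompose(const at::Tensor &q) {
  at::Tensor payload = q.int_repr();    // uint8 storage for quint8
  double scale = q.q_scale();           // per-tensor scale
  int64_t zeroPoint = q.q_zero_point(); // per-tensor zero point
  (void)payload; (void)scale; (void)zeroPoint;
}
```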

View File

@@ -146,3 +146,6 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %
   %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
   return
 }
+
+func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
+func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>

View File

@@ -22,6 +22,8 @@ class TestModule(torch.nn.Module):
         self.ones_f64 = torch.ones(1, dtype=torch.float64)
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
+        self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
+        self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
     # CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
@@ -32,6 +34,12 @@ class TestModule(torch.nn.Module):
     # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
     # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
     # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
+    # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
+    # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+    # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
+    # CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],si8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.qint8>
+    # CHECK: %[[ONES_QUINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
+    # CHECK: %[[ONES_QUINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QUINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],ui8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.quint8>
     # CHECK: %[[ROOT:.*]] = torch.nn_module {
     # CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
     # CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>
@@ -41,6 +49,8 @@ class TestModule(torch.nn.Module):
     # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
     # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
     # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
+    # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
+    # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
     # CHECK: }
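For reference, `torch.quantize_per_tensor(r, scale, zero_point, dtype)` stores `clamp(round(r / scale) + zero_point, qmin, qmax)`, so with `scale=1.0` and `zero_point=0` the stored payload for `1.0` is exactly `1`, matching the `dense<1>` literals checked above. A sketch of that math (illustrative helper, not PyTorch source):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative per-tensor affine quantization to quint8
// (qmin = 0, qmax = 255; qint8 would use -128..127).
uint8_t quantizeQUInt8(float r, float scale, int32_t zeroPoint) {
  // nearbyint rounds half to even under the default rounding mode,
  // matching PyTorch's quantization rounding.
  int32_t q = static_cast<int32_t>(std::nearbyint(r / scale)) + zeroPoint;
  return static_cast<uint8_t>(std::clamp(q, 0, 255));
}
```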