From 227dea7b2e71725af5aafca4c556c95c969a48a8 Mon Sep 17 00:00:00 2001
From: Sean Silva <silvasean@google.com>
Date: Sat, 25 Jun 2022 00:20:26 +0000
Subject: [PATCH] Add support for ScalarType::QUInt8

I ran into this while poking around at
https://github.com/llvm/torch-mlir/issues/959
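
For reference, this mirrors how PyTorch itself factors a
per-tensor-quantized tensor into raw integer data plus (scale,
zero_point), which is what the importer reassembles with
torch.per_tensor_affine.create below. A minimal sketch of that
correspondence in plain PyTorch (nothing torch-mlir-specific is
assumed here):

    import torch

    # Quantize a float tensor to quint8 with scale=1.0, zero_point=0.
    t = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
    print(t.int_repr())      # tensor([1], dtype=torch.uint8)
                             #   -> imported as dense<1> : tensor<1xui8>
    print(t.q_scale())       # 1.0 -> torch.constant.float 1.000000e+00
    print(t.q_zero_point())  # 0   -> torch.constant.int 0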
---
 development.md                                        |  1 +
 include/torch-mlir-c/TorchTypes.h                     | 10 ++++++++++
 include/torch-mlir/Dialect/Torch/IR/TorchTypes.td     | 12 ++++++++++++
 lib/CAPI/TorchTypes.cpp                               | 12 ++++++++++++
 lib/Dialect/Torch/IR/TorchTypes.cpp                   |  2 +-
 .../importer/jit_ir/csrc/torch_to_mlir_utils.cpp      |  5 +++++
 test/Dialect/Torch/ops.mlir                           |  3 +++
 test/python/importer/jit_ir/ivalue_import/tensors.py  | 10 ++++++++++
 8 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/development.md b/development.md
index 5e1b50df4..892df93b7 100644
--- a/development.md
+++ b/development.md
@@ -201,6 +201,7 @@ $TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonic
 ```
 
 Most of the unit tests use the [`FileCheck` tool](https://llvm.org/docs/CommandGuide/FileCheck.html) to verify expected outputs.
+
 # Updating the LLVM submodule
 
 Torch-MLIR maintains `llvm-project` (which contains, among other things,
diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h
index 586468ba7..8cff8da86 100644
--- a/include/torch-mlir-c/TorchTypes.h
+++ b/include/torch-mlir-c/TorchTypes.h
@@ -145,6 +145,16 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
 /// Gets the !torch.qint8 type.
 MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
 
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.quint8 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
+
+/// Gets the !torch.quint8 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index 6bc2b388e..e0780395e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -130,6 +130,7 @@ class AnyTorchTensorType
    | torch.int64       | si64               |
    | torch.bool        | i1                 |
    | torch.qint8       | !torch.qint8       |
+   | torch.quint8      | !torch.quint8      |
    |-------------------|--------------------|
    ```
 
@@ -295,6 +296,17 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
 
+def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
+  let summary = "Type modeling `ScalarType::QUInt8`";
+  let description = [{
+    This is intended to be a 1:1 match for the Torch `ScalarType` types.
+
+    Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
+    types, it is deemed preferable to import them as one-off ad-hoc types
+    instead of a single parameterized type.
+  }];
+}
+
 def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
   let summary = "Torch packed linear params type";
   let description = [{
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index ff01b9518..0c67453f3 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -174,6 +174,18 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
   return wrap(Torch::QInt8Type::get(unwrap(context)));
 }
 
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQUInt8(MlirType t) {
+  return unwrap(t).isa<Torch::QUInt8Type>();
+}
+
+MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
+  return wrap(Torch::QUInt8Type::get(unwrap(context)));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index bf4046cbc..10e2008ad 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -147,7 +147,7 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
 
 static bool isValidTorchDtype(Type dtype) {
   // Torch quantized types.
-  if (dtype.isa<Torch::QInt8Type>())
+  if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type>())
     return true;
   // Builtin floating point types.
   if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index dfe6db113..289b95c45 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -52,6 +52,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
     return mlirF16TypeGet(context);
   case ScalarType::QInt8:
     return torchMlirTorchQInt8TypeGet(context);
+  case ScalarType::QUInt8:
+    return torchMlirTorchQUInt8TypeGet(context);
   default: {
     return {nullptr};
   }
@@ -328,6 +330,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::QInt8:
     return mlirDenseElementsAttrInt8Get(
         shapedType, numElements, static_cast<const int8_t *>(tensorData));
+  case ScalarType::QUInt8:
+    return mlirDenseElementsAttrUInt8Get(
+        shapedType, numElements, static_cast<const uint8_t *>(tensorData));
   case ScalarType::BFloat16:
     return mlirDenseElementsAttrBFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index bf66318ea..9712c6b64 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -146,3 +146,6 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
   %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
   return
 }
+
+func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
+func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
diff --git a/test/python/importer/jit_ir/ivalue_import/tensors.py b/test/python/importer/jit_ir/ivalue_import/tensors.py
index ba18de31b..02a82f06e 100644
--- a/test/python/importer/jit_ir/ivalue_import/tensors.py
+++ b/test/python/importer/jit_ir/ivalue_import/tensors.py
@@ -22,6 +22,8 @@ class TestModule(torch.nn.Module):
         self.ones_f64 = torch.ones(1, dtype=torch.float64)
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
+        self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
+        self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
 
 # CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
@@ -32,6 +34,12 @@ class TestModule(torch.nn.Module):
 # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
 # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
 # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
+# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
+# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],si8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.qint8>
+# CHECK: %[[ONES_QUINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
+# CHECK: %[[ONES_QUINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QUINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],ui8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.quint8>
 # CHECK: %[[ROOT:.*]] = torch.nn_module {
 # CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
 # CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>
@@ -41,6 +49,8 @@ class TestModule(torch.nn.Module):
 # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
 # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
 # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
+# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
+# CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
 # CHECK: }
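
To see the new import path end to end, a minimal driver along the lines of
the tensors.py test can be run directly. A sketch, assuming an in-tree build
with the torch_mlir Python packages on PYTHONPATH and the `ModuleBuilder`
importer exposed at `torch_mlir.dialects.torch.importer.jit_ir` (the path
the test suite uses at this commit):

    import torch
    from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Exercises the new ScalarType::QUInt8 handling; expected to
            # import as !torch.tensor<[1],!torch.quint8>.
            self.ones_quint8 = torch.quantize_per_tensor(
                torch.ones(1), 1.0, 0, torch.quint8)

    mb = ModuleBuilder()
    mb.import_module(torch.jit.script(M())._c)
    mb.module.operation.print()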