Add support for ScalarType::QUInt8

I ran into this while poking around at https://github.com/llvm/torch-mlir/issues/959.

(mirror of https://github.com/llvm/torch-mlir; commit 227dea7b2e, parent cd79538a0c)
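The change threads the new quint8 type through the C API, the Torch dialect's TableGen type definitions, dtype validation, and the JIT IR importer, with matching tests. For context, a `ScalarType::QUInt8` tensor arises whenever a float tensor is per-tensor quantized to unsigned 8-bit; here is a minimal sketch using the ATen C++ API (the importer test further down does the same thing via `torch.quantize_per_tensor` in Python):

```cpp
// Minimal sketch of how a ScalarType::QUInt8 tensor arises (ATen C++ API).
// Before this change, the importer had no !torch type to map it to and
// fell through to the `return {nullptr}` default case shown below.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Quantize a float tensor to unsigned 8-bit with scale 1.0, zero point 0.
  at::Tensor q = at::quantize_per_tensor(torch::ones({1}), /*scale=*/1.0,
                                         /*zero_point=*/0, at::kQUInt8);
  std::cout << q.scalar_type() << "\n"; // prints "QUInt8"
}
```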
@@ -201,6 +201,7 @@ $TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonic
 ```
 
 Most of the unit tests use the [`FileCheck` tool](https://llvm.org/docs/CommandGuide/FileCheck.html) to verify expected outputs.
 
 # Updating the LLVM submodule
 
 Torch-MLIR maintains `llvm-project` (which contains, among other things,
@@ -145,6 +145,16 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
 /// Gets the !torch.qint8 type.
 MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
 
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.quint8 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
+
+/// Gets the !torch.quint8 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
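A hedged usage sketch of the two new entry points declared above; it assumes an `MlirContext` with the torch dialect already registered (context setup elided):

```cpp
#include "torch-mlir-c/TorchTypes.h"
#include <cassert>

void checkQUInt8(MlirContext ctx) {
  // Build the !torch.quint8 type and confirm the matching predicate...
  MlirType quint8 = torchMlirTorchQUInt8TypeGet(ctx);
  assert(torchMlirTypeIsATorchQUInt8(quint8));
  // ...and that the signed-variant predicate rejects it.
  assert(!torchMlirTypeIsATorchQInt8(quint8));
}
```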
@@ -130,6 +130,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
 | torch.int64       | si64               |
 | torch.bool        | i1                 |
 | torch.qint8       | !torch.qint8       |
+| torch.quint8      | !torch.quint8      |
 |-------------------|--------------------|
 ```
 
@@ -295,6 +296,17 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
 
+def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
+  let summary = "Type modeling `ScalarType::QUInt8`";
+  let description = [{
+    This is intended to be a 1:1 match for the Torch `ScalarType` types.
+
+    Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
+    types, it is deemed preferable to import them as one-off ad-hoc types
+    instead of a single parameterized type.
+  }];
+}
+
 def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
   let summary = "Torch packed linear params type";
   let description = [{
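Each `Torch_Type` definition makes TableGen emit a C++ type class in the `Torch` namespace; a sketch of how dialect code picks up the new one (the same `QUInt8Type::get` call appears in the C API implementation below; the include path is assumed from the repo layout):

```cpp
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// Construct the generated singleton type from a context...
mlir::Type makeQUInt8Type(mlir::MLIRContext *context) {
  return mlir::torch::Torch::QUInt8Type::get(context);
}

// ...and test for it with isa<>, as isValidTorchDtype does further down.
bool isQUInt8(mlir::Type type) {
  return type.isa<mlir::torch::Torch::QUInt8Type>();
}
```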
@@ -174,6 +174,18 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
   return wrap(Torch::QInt8Type::get(unwrap(context)));
 }
 
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQUInt8(MlirType t) {
+  return unwrap(t).isa<Torch::QUInt8Type>();
+}
+
+MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
+  return wrap(Torch::QUInt8Type::get(unwrap(context)));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
@@ -147,7 +147,7 @@ void Torch::UnionType::print(AsmPrinter &printer) const {
 
 static bool isValidTorchDtype(Type dtype) {
   // Torch quantized types.
-  if (dtype.isa<Torch::QInt8Type>())
+  if (dtype.isa<Torch::QInt8Type, Torch::QUInt8Type>())
     return true;
   // Builtin floating point types.
   if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
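For reference, `isa` with several template arguments is an "any of" check, so the updated predicate above is equivalent to this chained form:

```cpp
// Equivalent spelling: dtype.isa<A, B>() is true iff dtype is A or B.
bool isQuantizedTorchDtype(mlir::Type dtype) {
  return dtype.isa<mlir::torch::Torch::QInt8Type>() ||
         dtype.isa<mlir::torch::Torch::QUInt8Type>();
}
```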
@@ -52,6 +52,8 @@ static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context,
     return mlirF16TypeGet(context);
   case ScalarType::QInt8:
     return torchMlirTorchQInt8TypeGet(context);
+  case ScalarType::QUInt8:
+    return torchMlirTorchQUInt8TypeGet(context);
   default: {
     return {nullptr};
   }
@@ -328,6 +330,9 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
   case ScalarType::QInt8:
     return mlirDenseElementsAttrInt8Get(
         shapedType, numElements, static_cast<const int8_t *>(tensorData));
+  case ScalarType::QUInt8:
+    return mlirDenseElementsAttrUInt8Get(
+        shapedType, numElements, static_cast<const uint8_t *>(tensorData));
   case ScalarType::BFloat16:
     return mlirDenseElementsAttrBFloat16Get(
         shapedType, numElements, static_cast<const uint16_t *>(tensorData));
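Reading the data out as `uint8_t` is sound because a quint8 tensor's backing storage is plain unsigned bytes; a hedged ATen sketch of that invariant (the importer test below then checks that the value is rebuilt with `torch.per_tensor_affine.create` from this raw data plus scale and zero point):

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
  at::Tensor q = at::quantize_per_tensor(torch::ones({4}), /*scale=*/1.0,
                                         /*zero_point=*/0, at::kQUInt8);
  // int_repr() exposes the raw storage; for quint8 it is a kByte (uint8)
  // tensor, matching the static_cast<const uint8_t *> read above.
  assert(q.int_repr().scalar_type() == at::kByte);
}
```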
@@ -146,3 +146,6 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %
   %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
   return
 }
+
+func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
+func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
@@ -22,6 +22,8 @@ class TestModule(torch.nn.Module):
         self.ones_f64 = torch.ones(1, dtype=torch.float64)
         self.ones_bool = torch.ones(1, dtype=torch.bool)
         self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
+        self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
+        self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
         self.arange = torch.nn.Parameter(torch.arange(3.0))
 
 # CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
@@ -32,6 +34,12 @@ class TestModule(torch.nn.Module):
 # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
 # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
 # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
+# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
+# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],si8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.qint8>
+# CHECK: %[[ONES_QUINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8>
+# CHECK: %[[ONES_QUINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QUINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[1],ui8>, !torch.float, !torch.int -> !torch.tensor<[1],!torch.quint8>
 # CHECK: %[[ROOT:.*]] = torch.nn_module {
 # CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
 # CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>
@@ -41,6 +49,8 @@ class TestModule(torch.nn.Module):
 # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
 # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
 # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
+# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
+# CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
 # CHECK: }