[ONNX] add int16 quantization support (#3446)

There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
pull/3454/head
zjgarvey 2024-06-12 00:07:22 -05:00 committed by GitHub
parent 7cd3368b20
commit de28c8540b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 85 additions and 44 deletions

View File

@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
/// Gets the !torch.quint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
//===----------------------------------------------------------------------===//
// torch.qint16 type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.qint16 type
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);
/// Gets the !torch.qint16 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);
/// Gets the !torch.qint16 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//

View File

@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput);
Type getQTorchTypeFromTorchIntType(Type ty);
Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,

View File

@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
}];
}
def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
let description = [{
Pytorch does not have 16-bit integer quantization support.
This torch type is added to provide a target for 16-bit quantization
schemes coming from imported onnx models.
}];
}
def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
let summary = "Type modeling `ScalarType::QUInt8`";
let description = [{

View File

@ -112,7 +112,8 @@ enum class TypeKind {
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(c10::qint16, QInt16) /* 27 */
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,

View File

@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
return wrap(Torch::QUInt8Type::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.qint16 type.
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchQInt16(MlirType t) {
return isa<Torch::QInt16Type>(unwrap(t));
}
MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
return wrap(Torch::QInt16Type::get(unwrap(context)));
}
MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
return wrap(Torch::QInt16Type::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//

View File

@ -1715,21 +1715,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
"requires known result dtype");
if (scaleTy.getSizes().size() == 0 ||
(scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
Type qTy = operandTy.getDtype();
if (qTy.isUnsignedInteger(8)) {
qTy = rewriter.getType<Torch::QUInt8Type>();
} else if (qTy.isSignedInteger(8)) {
qTy = rewriter.getType<Torch::QInt8Type>();
} else if (qTy.isSignedInteger(32)) {
qTy = rewriter.getType<Torch::QInt32Type>();
} else {
auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
if (!qTensorTy) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}
auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), qTy);
scale = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
zeropoint = rewriter.create<Torch::AtenItemOp>(

View File

@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));
auto q = [&](Type qty) -> Type {
if (qty.isSignedInteger(8))
return rewriter.getType<Torch::QInt8Type>();
if (qty.isUnsignedInteger(8))
return rewriter.getType<Torch::QUInt8Type>();
if (qty.isSignedInteger(32))
return rewriter.getType<Torch::QInt32Type>();
return {};
};
auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));
if (!lhsQTy || !rhsQTy)
return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), lhsQTy, lhs, scale, lhsZp);

View File

@ -177,22 +177,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"requires known result dtype");
if (scaleTy.getSizes().size() == 0) {
Type qTy = resultType.getDtype();
if (qTy.isUnsignedInteger(8)) {
qTy = rewriter.getType<Torch::QUInt8Type>();
} else if (qTy.isSignedInteger(8)) {
qTy = rewriter.getType<Torch::QInt8Type>();
} else if (qTy.isSignedInteger(32)) {
qTy = rewriter.getType<Torch::QInt32Type>();
} else {
auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
if (!qTensorTy) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}
auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), qTy);
auto torchqTy = Torch::getScalarTypeForType(qTy);
auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@ -311,8 +302,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
c);
cTy = dyn_cast<Torch::ValueTensorType>(
getQTorchTypeFromTorchIntType(resultType));
cTy = getQTorchTypeFromTorchIntType(resultType);
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(

View File

@ -28,7 +28,8 @@ Value mlir::torch::onnx_c::createConstantIntList(
cstValue);
}
Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
Torch::ValueTensorType
mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
if (!tty)
return nullptr;
@ -40,6 +41,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
dty = Torch::QUInt8Type::get(ctx);
if (dty.isSignedInteger(8))
dty = Torch::QInt8Type::get(ctx);
if (dty.isSignedInteger(16))
dty = Torch::QInt16Type::get(ctx);
if (dty.isSignedInteger(32))
dty = Torch::QInt32Type::get(ctx);

View File

@ -565,6 +565,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) {
return false;
if (isa<QUInt8Type>(type))
return true;
if (isa<QInt16Type>(type))
return false;
if (isa<QInt32Type>(type))
return false;
if (auto intTy = dyn_cast<IntegerType>(type))

View File

@ -185,7 +185,8 @@ static bool isValidTorchDtype(Type dtype) {
dtype = cast<ComplexType>(dtype).getElementType();
}
// Torch quantized types.
if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>(dtype))
if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt16Type,
Torch::QInt32Type>(dtype))
return true;
// Builtin floating point types.
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
@ -463,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (isa<QInt8Type>(dtype))
return IntegerType::get(context, 8, IntegerType::Signless);
if (isa<QInt16Type>(dtype))
return IntegerType::get(context, 16, IntegerType::Signless);
if (isa<QInt32Type>(dtype))
return IntegerType::get(context, 32, IntegerType::Signless);

View File

@ -21,10 +21,12 @@ using namespace mlir::torch::Torch;
namespace {
Type getQuantizedType(MLIRContext *context, Type t) {
if (t.isSignlessInteger(8))
if (t.isSignlessInteger(8) || t.isUnsignedInteger(8))
return Torch::QUInt8Type::get(context);
if (t.isInteger(8) || t.isSignedInteger(8))
return Torch::QInt8Type::get(context);
if (t.isInteger(16))
return Torch::QInt16Type::get(context);
if (t.isInteger(32))
return Torch::QInt32Type::get(context);
return {};

View File

@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
t == ScalarType::QUInt2x4;
t == ScalarType::QUInt2x4 || t == ScalarType::QInt16;
}
//===----------------------------------------------------------------------===//

View File

@ -69,6 +69,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::QUInt8;
if (isa<QInt8Type>(type))
return torch_upstream::ScalarType::QInt8;
if (isa<QInt16Type>(type))
return torch_upstream::ScalarType::QInt16;
if (isa<QInt32Type>(type))
return torch_upstream::ScalarType::QInt32;
if (isa<ComplexType>(type)) {
@ -128,6 +130,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
return QUInt8Type::get(context);
case torch_upstream::ScalarType::QInt8:
return QInt8Type::get(context);
case torch_upstream::ScalarType::QInt16:
return QInt16Type::get(context);
case torch_upstream::ScalarType::QInt32:
return QInt32Type::get(context);
case torch_upstream::ScalarType::ComplexHalf:

View File

@ -748,6 +748,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor
// -----
// CHECK-LABEL: @test_dequantizelinear_si16
func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32>
// CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
// CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]]
// CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]]
// CHECK: return %[[DEQ]]
return %0 : !torch.vtensor<[6],f32>
}
// -----
// CHECK-LABEL: @test_dequantizelinear_ui8
func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32>

View File

@ -171,6 +171,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %
func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16>
func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>> {
%arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>