mirror of https://github.com/llvm/torch-mlir
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch adds a new mlir type to correspond to the missing "torch.qint16" type, and enables lowering of quantization-related onnx ops using int16 types. In follow-up patches, custom quantization logic for ops like aten.matmul/aten.mm/aten.convolution may need to be revisited to allow support for qint16. The passes in FuseQuantizedOps.cpp may also need slight modifications.pull/3454/head
parent 7cd3368b20
commit de28c8540b
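As a quick illustration of the new type from the C API additions below, here is a minimal sketch (not part of the patch). It assumes the usual torchMlirRegisterAllDialects entry point from torch-mlir-c/Registration.h both registers and loads the torch dialect; adjust registration to your embedding.

#include <cassert>

#include "mlir-c/IR.h"
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/TorchTypes.h"

int main() {
  MlirContext ctx = mlirContextCreate();
  // Assumption: this registers and loads the torch dialect; if your build only
  // registers it, load the dialect explicitly before constructing types.
  torchMlirRegisterAllDialects(ctx);

  // Construct the new !torch.qint16 type and query it through the new API.
  MlirType qint16 = torchMlirTorchQInt16TypeGet(ctx);
  assert(torchMlirTypeIsATorchQInt16(qint16));
  mlirTypeDump(qint16); // expected to print: !torch.qint16

  mlirContextDestroy(ctx);
  return 0;
}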
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
 /// Gets the !torch.quint8 typeid.
 MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint16 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);
+
+/// Gets the !torch.qint16 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);
+
+/// Gets the !torch.qint16 typeid.
+MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//

@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
 ConversionPatternRewriter &rewriter,
 SmallVector<int64_t> cstInput);
 
-Type getQTorchTypeFromTorchIntType(Type ty);
+Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
 
 template <typename T>
 Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,

@@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
 }];
 }
 
+def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
+let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
+let description = [{
+Pytorch does not have 16-bit integer quantization support.
+
+This torch type is added to provide a target for 16-bit quantization
+schemes coming from imported onnx models.
+}];
+}
+
 def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
 let summary = "Type modeling `ScalarType::QUInt8`";
 let description = [{

@@ -112,7 +112,8 @@ enum class TypeKind {
 _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
 _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
 _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
-_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
+_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
+_(c10::qint16, QInt16) /* 27 */
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,

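The entry added above extends an X-macro list: each `_(c10Type, Name)` pair later expands into one enumerator of `ScalarType`, so `QInt16` picks up the next value (27). A self-contained toy version of the same pattern, with invented names purely for illustration:

#include <cstdint>
#include <iostream>

// Toy stand-in for the AT_FORALL-style list above; the names here are
// hypothetical and only demonstrate how the X-macro expands into an enum.
#define FORALL_DEMO_TYPES(_) \
  _(float, Float)            \
  _(int16_t, QInt16)

enum class DemoScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
  FORALL_DEMO_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
};

int main() {
  // Float = 0, QInt16 = 1 in this toy list.
  std::cout << static_cast<int>(DemoScalarType::QInt16) << "\n";
  return 0;
}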
@@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
 return wrap(Torch::QUInt8Type::getTypeID());
 }
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQInt16(MlirType t) {
+return isa<Torch::QInt16Type>(unwrap(t));
+}
+
+MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
+return wrap(Torch::QInt16Type::get(unwrap(context)));
+}
+
+MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
+return wrap(Torch::QInt16Type::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//

@@ -1715,21 +1715,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 "requires known result dtype");
 if (scaleTy.getSizes().size() == 0 ||
 (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-Type qTy = operandTy.getDtype();
-
-if (qTy.isUnsignedInteger(8)) {
-qTy = rewriter.getType<Torch::QUInt8Type>();
-} else if (qTy.isSignedInteger(8)) {
-qTy = rewriter.getType<Torch::QInt8Type>();
-} else if (qTy.isSignedInteger(32)) {
-qTy = rewriter.getType<Torch::QInt32Type>();
-} else {
+auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+if (!qTensorTy) {
 return rewriter.notifyMatchFailure(binder.op,
 "unsupported result dtype");
 }
 
-auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-resultType.getOptionalSizes(), qTy);
 scale = rewriter.create<Torch::AtenItemOp>(
 binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
 zeropoint = rewriter.create<Torch::AtenItemOp>(

@@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 binder.getLoc(), rewriter.getType<Torch::FloatType>(),
 rewriter.getF64FloatAttr(1.0));
 
-auto q = [&](Type qty) -> Type {
-if (qty.isSignedInteger(8))
-return rewriter.getType<Torch::QInt8Type>();
-if (qty.isUnsignedInteger(8))
-return rewriter.getType<Torch::QUInt8Type>();
-if (qty.isSignedInteger(32))
-return rewriter.getType<Torch::QInt32Type>();
-return {};
-};
+auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
+auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
 
-Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
-lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
-Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
-rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));
+if (!lhsQTy || !rhsQTy)
+return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
 
 lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
 binder.getLoc(), lhsQTy, lhs, scale, lhsZp);

@@ -177,22 +177,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 "requires known result dtype");
 
 if (scaleTy.getSizes().size() == 0) {
-Type qTy = resultType.getDtype();
-
-if (qTy.isUnsignedInteger(8)) {
-qTy = rewriter.getType<Torch::QUInt8Type>();
-} else if (qTy.isSignedInteger(8)) {
-qTy = rewriter.getType<Torch::QInt8Type>();
-} else if (qTy.isSignedInteger(32)) {
-qTy = rewriter.getType<Torch::QInt32Type>();
-} else {
+auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
+if (!qTensorTy) {
 return rewriter.notifyMatchFailure(binder.op,
 "unsupported result dtype");
 }
 
-auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-resultType.getOptionalSizes(), qTy);
-auto torchqTy = Torch::getScalarTypeForType(qTy);
+auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
 
 Value tyConst = rewriter.create<Torch::ConstantIntOp>(
 binder.getLoc(), rewriter.getType<Torch::IntType>(),

@@ -311,8 +302,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
 c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
 c);
-cTy = dyn_cast<Torch::ValueTensorType>(
-getQTorchTypeFromTorchIntType(resultType));
+cTy = getQTorchTypeFromTorchIntType(resultType);
 Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
 binder.getLoc(), rewriter.getType<Torch::IntType>(),
 rewriter.getIntegerAttr(

@@ -28,7 +28,8 @@ Value mlir::torch::onnx_c::createConstantIntList(
 cstValue);
 }
 
-Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
+Torch::ValueTensorType
+mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
 Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
 if (!tty)
 return nullptr;

@@ -40,6 +41,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
 dty = Torch::QUInt8Type::get(ctx);
 if (dty.isSignedInteger(8))
 dty = Torch::QInt8Type::get(ctx);
+if (dty.isSignedInteger(16))
+dty = Torch::QInt16Type::get(ctx);
 if (dty.isSignedInteger(32))
 dty = Torch::QInt32Type::get(ctx);
 
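Taken together, the two Utils.cpp hunks above change the helper so it returns a quantized !torch.vtensor type directly (or a null type when the dtype has no quantized counterpart). A sketch of the whole function after this patch; the lines not visible in the hunks (context/dtype setup, the final null check and return) are reconstructed assumptions rather than verbatim source:

// Sketch only: reconstructed from the two hunks above.
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

Torch::ValueTensorType
mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
  Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
  if (!tty)
    return nullptr;

  MLIRContext *ctx = ty.getContext(); // assumed setup, not shown in the diff
  Type dty = tty.getDtype();          // assumed setup, not shown in the diff

  if (dty.isUnsignedInteger(8))
    dty = Torch::QUInt8Type::get(ctx);
  if (dty.isSignedInteger(8))
    dty = Torch::QInt8Type::get(ctx);
  if (dty.isSignedInteger(16))
    dty = Torch::QInt16Type::get(ctx); // new: signed 16-bit maps to !torch.qint16
  if (dty.isSignedInteger(32))
    dty = Torch::QInt32Type::get(ctx);

  // Assumed tail: callers such as the DequantizeLinear lowering check for a
  // null result, so non-quantizable dtypes are presumed to return nullptr here.
  if (!isa<Torch::QUInt8Type, Torch::QInt8Type, Torch::QInt16Type,
           Torch::QInt32Type>(dty))
    return nullptr;

  return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty);
}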
@@ -565,6 +565,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) {
 return false;
 if (isa<QUInt8Type>(type))
 return true;
+if (isa<QInt16Type>(type))
+return false;
 if (isa<QInt32Type>(type))
 return false;
 if (auto intTy = dyn_cast<IntegerType>(type))

@@ -185,7 +185,8 @@ static bool isValidTorchDtype(Type dtype) {
 dtype = cast<ComplexType>(dtype).getElementType();
 }
 // Torch quantized types.
-if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt32Type>(dtype))
+if (isa<Torch::QInt8Type, Torch::QUInt8Type, Torch::QInt16Type,
+Torch::QInt32Type>(dtype))
 return true;
 // Builtin floating point types.
 if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))

@@ -463,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
 if (isa<QInt8Type>(dtype))
 return IntegerType::get(context, 8, IntegerType::Signless);
 
+if (isa<QInt16Type>(dtype))
+return IntegerType::get(context, 16, IntegerType::Signless);
+
 if (isa<QInt32Type>(dtype))
 return IntegerType::get(context, 32, IntegerType::Signless);
 

@@ -21,10 +21,12 @@ using namespace mlir::torch::Torch;
 namespace {
 
 Type getQuantizedType(MLIRContext *context, Type t) {
-if (t.isSignlessInteger(8))
+if (t.isSignlessInteger(8) || t.isUnsignedInteger(8))
 return Torch::QUInt8Type::get(context);
 if (t.isInteger(8) || t.isSignedInteger(8))
 return Torch::QInt8Type::get(context);
+if (t.isInteger(16))
+return Torch::QInt16Type::get(context);
 if (t.isInteger(32))
 return Torch::QInt32Type::get(context);
 return {};

@@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) {
 // Don't forget to extend this when adding new QInt types
 return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
 t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
-t == ScalarType::QUInt2x4;
+t == ScalarType::QUInt2x4 || t == ScalarType::QInt16;
 }
 
 //===----------------------------------------------------------------------===//

@@ -69,6 +69,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
 return torch_upstream::ScalarType::QUInt8;
 if (isa<QInt8Type>(type))
 return torch_upstream::ScalarType::QInt8;
+if (isa<QInt16Type>(type))
+return torch_upstream::ScalarType::QInt16;
 if (isa<QInt32Type>(type))
 return torch_upstream::ScalarType::QInt32;
 if (isa<ComplexType>(type)) {

@@ -128,6 +130,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
 return QUInt8Type::get(context);
 case torch_upstream::ScalarType::QInt8:
 return QInt8Type::get(context);
+case torch_upstream::ScalarType::QInt16:
+return QInt16Type::get(context);
 case torch_upstream::ScalarType::QInt32:
 return QInt32Type::get(context);
 case torch_upstream::ScalarType::ComplexHalf:

@@ -748,6 +748,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor
 
 // -----
 
+// CHECK-LABEL: @test_dequantizelinear_si16
+func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+%0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32>
+// CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+// CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
+// CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]]
+// CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]]
+// CHECK: return %[[DEQ]]
+return %0 : !torch.vtensor<[6],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_dequantizelinear_ui8
 func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
 %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32>

@@ -171,6 +171,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %
 
 func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
 func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
+func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16>
 
 func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>> {
 %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>