[torch] Basic support for per-channel quantized graphs (#3623)

This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now, but they can still be used in QDQ
(quantize-dequantize) form.

Using this patch, we're able to import and execute (on the linalg
backend) graphs with per-channel quantization applied using the "new"
PyTorch 2.0 Export Quantization.
Felix Schneider 2024-08-10 15:51:09 +02:00 committed by GitHub
parent 44266ab0c4
commit 0314188dbe
6 changed files with 215 additions and 35 deletions
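For reference, a graph in this form can be produced with the PT2E export quantization flow and then imported through the fx importer. The snippet below is only an illustrative sketch, not part of this patch; the capture entry point (capture_pre_autograd_graph) and the torch_mlir fx.export_and_import call are assumptions that depend on the PyTorch and torch-mlir versions in use.

import torch
# Assumption: older PT2E capture API; newer PyTorch releases use torch.export instead.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch_mlir import fx  # assumption: torch-mlir fx importer entry point


class SmallConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


example_inputs = (torch.randn(1, 3, 16, 16),)
model = capture_pre_autograd_graph(SmallConv().eval(), example_inputs)

# Per-channel weight quantization: the converted QDQ graph is expected to
# contain quantized_decomposed.dequantize_per_channel ops, which this patch
# teaches the lowering to handle.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True)
)
prepared = prepare_pt2e(model, quantizer)
prepared(*example_inputs)  # calibration run
quantized = convert_pt2e(prepared)

# Import to the torch dialect; per-channel quantized ops stay in QDQ form
# (no FuseQuantizedOps fusion) and run through the linalg backend.
mlir_module = fx.export_and_import(quantized, *example_inputs)
print(mlir_module)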


@@ -2350,6 +2350,10 @@ public:
} else if (zeropointDTy.isSignedInteger(8)) {
zeropoint =
b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
} else if (zeropointDTy.isInteger(64)) {
zeropoint =
b.create<arith::TruncIOp>(loc, b.getI32Type(), zeropoint);
op->emitWarning() << "truncated zero point from 64 to 32 bit";
}
Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);


@@ -44,6 +44,11 @@ bool isQCommutingOp(mlir::Operation *op) {
op);
}
struct QuantizedChain {
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
};
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
@@ -58,10 +63,8 @@ public:
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
mlir::Location loc = op.getLoc();
llvm::SmallVector<Value> operands(op->getOperands());
bool dequanted = false;
// Prevent fusion for 1d convolution ops and just do it as an f32 conv since
// there isn't a linalg named op for quantized 1-d convolution yet.
@@ -72,10 +75,10 @@ public:
return rewriter.notifyMatchFailure(
op, "1-d quantized convolution is not supported");
SmallVector<QuantizedChain, 2> operandChains;
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
QuantizedChain chain;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
@@ -83,40 +86,59 @@ public:
break;
// Case 1 : currOp is a q commuting op (continue loop)
if (isQCommutingOp(currOp)) {
commutingOpStack.push(currOp);
chain.commutingOpStack.push(currOp);
// set operand to currOp for next k-iteration
operand = currOp->getOperand(0);
continue;
}
// Case 2 : currOp is a dequant op (end loop)
if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
dequantOpd = currOp->getOperand(0);
chain.dequantOpd = currOp->getOperand(0);
// Bail out if any operand is per-channel quantized, which would
// require more complex fusion logic.
if (llvm::isa<Aten_MakePerChannelQuantizedTensorOp>(
chain.dequantOpd.getDefiningOp()))
break;
auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0);
scale = MPTQTOp.getOperand(1);
zeroPoint = MPTQTOp.getOperand(2);
chain.dequantOpd
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
chain.MPTQTOpd = MPTQTOp.getOperand(0);
chain.scale = MPTQTOp.getOperand(1);
chain.zeroPoint = MPTQTOp.getOperand(2);
}
// either a dequant was found or chain broken, so break loop
break;
}
// move to next operand if this trace was unsuccessful
if (!MPTQTOpd)
continue;
// if tracing this operand was successful, add it to operandChains.
if (chain.MPTQTOpd)
operandChains.push_back(std::move(chain));
}
// a successful trace occurred, so set dequanted to true
dequanted = true;
// Continuing the rewriting with only some of the operandsToQuantize traced
// successfully is possible but leads to "half-quantized" ops which are
// expected to cause problems in later lowering steps. We opt out of
// treating these cases for now.
if (operandChains.size() !=
std::size(QuantInfo<SrcOp>::operandsToQuantize)) {
if (!operandChains.empty())
op.emitWarning("Partially traced quantized operands. This op will "
"remain in QDQ form.");
return rewriter.notifyMatchFailure(
op, "did not find a complete quantized chain for all operands");
}
for (auto &&[i, chain] : llvm::enumerate(operandChains)) {
// rewrite stack
Value oldOpd = MPTQTOpd;
Value oldOpd = chain.MPTQTOpd;
Type intDType =
cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
while (!commutingOpStack.empty()) {
cast<ValueTensorType>(chain.MPTQTOpd.getType()).getOptionalDtype();
while (!chain.commutingOpStack.empty()) {
// get front of the commuting op stack and replace its first operand
// with oldOpd
auto currOp = commutingOpStack.top();
commutingOpStack.pop();
auto currOp = chain.commutingOpStack.top();
chain.commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// pad ops aren't quite commuting, so we include some extra logic to
@@ -125,14 +147,15 @@ public:
Value floatPadValue = currOperands.back();
Value quantPadValue;
if (isa<Torch::NoneType>(floatPadValue.getType()))
quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
quantPadValue =
rewriter.create<AtenFloatScalarOp>(loc, chain.zeroPoint);
else {
floatPadValue =
rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
loc, floatPadValue, scale);
loc, floatPadValue, chain.scale);
quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
loc, quantPadValue, zeroPoint);
loc, quantPadValue, chain.zeroPoint);
}
// clamp pad value to qint range
if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
@@ -175,19 +198,15 @@ public:
// stack is empty, so oldOpd is now the corrected version of the
// SrcOp's original operand
// convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
auto MPTQTOperands = chain.dequantOpd.getDefiningOp()->getOperands();
auto qTorchType =
cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
cast<ValueTensorType>(chain.dequantOpd.getType()).getOptionalDtype();
auto newMPTQTType = rewriter.getType<ValueTensorType>(
cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
}
if (!dequanted) {
return rewriter.notifyMatchFailure(op, "No dequantizations found.");
}
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
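At the PyTorch level, the chains matched by the pattern above look roughly like the following sketch (a hypothetical illustration using the ops exercised by the e2e tests in this patch, not code taken from the diff): a per-tensor quantized tensor is dequantized, passes through a q-commuting op such as a pad, and finally feeds the op to be quantized.

import torch

x_int = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
w_int = torch.randint(-128, 127, (4, 3, 3, 3), dtype=torch.int8)

# op0 -> _make_per_tensor_quantized_tensor (MPTQT) -> dequant ...
x = torch.dequantize(torch._make_per_tensor_quantized_tensor(x_int, 0.02, 0))
w = torch.dequantize(torch._make_per_tensor_quantized_tensor(w_int, 0.01, 0))

# ... -> q-commuting op (pad) -> SrcOp (conv2d). The rewrite pushes the
# quantization past the pad so the convolution can run on integer operands;
# per-channel quantized operands are deliberately skipped and stay in QDQ form.
x = torch.nn.functional.pad(x, [1, 1, 1, 1])
y = torch.ops.aten.conv2d(x, w, None, [1, 1], [0, 0], [1, 1], 1)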


@@ -59,10 +59,11 @@ public:
return success();
}
if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
auto clamp = rewriter.create<AtenClampOp>(
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
op.getOperand(3), op.getOperand(4));
auto prepareDequantize = [&](Value quantMin, Value quantMax, Value &clamp,
Type &qTy) {
clamp =
rewriter.create<AtenClampOp>(op.getLoc(), op.getOperand(0).getType(),
op.getOperand(0), quantMin, quantMax);
auto clampTy = cast<Torch::ValueTensorType>(clamp.getType());
if (!clampTy.hasDtype())
@@ -75,8 +76,18 @@ public:
return rewriter.notifyMatchFailure(op,
"dequantization has unknown qtype");
Type qTy = Torch::ValueTensorType::get(
op.getContext(), clampTy.getOptionalSizes(), qetype);
qTy = Torch::ValueTensorType::get(op.getContext(),
clampTy.getOptionalSizes(), qetype);
return success();
};
if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
Value clamp;
Type qTy;
if (failed(prepareDequantize(op.getOperand(3), op.getOperand(4), clamp,
qTy)))
return failure();
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2));
rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(
@@ -84,6 +95,20 @@ public:
return success();
}
if (op.getName() == "torch.quantized_decomposed.dequantize_per_channel") {
Value clamp;
Type qTy;
if (failed(prepareDequantize(op.getOperand(4), op.getOperand(5), clamp,
qTy)))
return failure();
auto quant = rewriter.create<Aten_MakePerChannelQuantizedTensorOp>(
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2),
op.getOperand(3));
rewriter.replaceOpWithNewOp<AtenDequantizeSelfOp>(op, op.getResultTypes(),
quant);
return success();
}
return failure();
}
};


@@ -280,6 +280,9 @@ TORCHDYNAMO_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
@@ -382,6 +385,9 @@ FX_IMPORTER_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
@@ -550,6 +556,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
@@ -2224,6 +2233,9 @@ LTC_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTranspose2DQInt8_basic",
}
@@ -2374,6 +2386,9 @@ ONNX_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
@@ -2953,6 +2968,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv3dModule_basic",
@@ -3748,6 +3766,9 @@ ONNX_TOSA_XFAIL_SET = {
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",


@@ -1309,6 +1309,96 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
)
class Conv2dQInt8PerChannelModuleBase(torch.nn.Module):
def __init__(self, groups=1):
self.groups = groups
super().__init__()
def _forward(self, inputVec, weight, scales, zeropoints, bias):
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
inputVec = torch.dequantize(inputVec)
weight = torch._make_per_channel_quantized_tensor(
weight, scales, zeropoints, axis=0
)
weight = torch.dequantize(weight)
bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
bias = torch.dequantize(bias)
return torch.ops.aten.conv2d(
inputVec,
weight,
bias=bias,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
groups=self.groups,
)
class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
([-1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, inputVec, weight, scales, zeropoints, bias):
return self._forward(inputVec, weight, scales, zeropoints, bias)
class Conv2dQInt8PerChannelModuleStatic(Conv2dQInt8PerChannelModuleBase):
@export
@annotate_args(
[
None,
([2, 3, 12, 12], torch.int8, True),
([3, 1, 5, 3], torch.int8, True),
([3], torch.float, True),
([3], torch.int8, True),
([3], torch.float, True),
]
)
def forward(self, inputVec, weight, scales, zeropoints, bias):
return self._forward(inputVec, weight, scales, zeropoints, bias)
@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn())
def Conv2dQInt8PerChannelModule_basic(module, tu: TestUtils):
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
scales = tu.rand(3)
zeropoints = tu.rand(3).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, scales, zeropoints, bias)
@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn(groups=2))
def Conv2dQInt8PerChannelModule_grouped(module, tu: TestUtils):
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
scales = tu.rand(6)
zeropoints = tu.rand(6).to(torch.int8)
bias = torch.rand(6)
module.forward(inputVec, weight, scales, zeropoints, bias)
@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleStatic(groups=3))
def Conv2dQInt8PerChannelModule_depthwise(module, tu: TestUtils):
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
scales = tu.rand(3)
zeropoints = tu.rand(3).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, scales, zeropoints, bias)
# torchvision.deform_conv2d
import torchvision


@@ -40,3 +40,24 @@ func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
return %13 : !torch.vtensor<[1,3,8,8],f32>
}
// -----
// CHECK-LABEL: func.func @dequantize_per_channel
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[32,3,8,8],si8>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[32],f32>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[32],si8>) -> !torch.vtensor<[32,3,8,8],f32> {
func.func @dequantize_per_channel(%arg0: !torch.vtensor<[32,3,8,8],si8>, %arg1: !torch.vtensor<[32],f32>, %arg2: !torch.vtensor<[32],si8>) -> !torch.vtensor<[32,3,8,8],f32> {
%min = torch.constant.int -128
%max = torch.constant.int 127
%dtype = torch.constant.int 1
%axis = torch.constant.int 0
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
// CHECK-DAG: %[[AXIS:.+]] = torch.constant.int 0
// CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[32,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[32,3,8,8],si8>
// CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_channel_quantized_tensor %[[CLAMP]], %[[ARG1]], %[[ARG2]], %[[AXIS]] : !torch.vtensor<[32,3,8,8],si8>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],si8>, !torch.int -> !torch.vtensor<[32,3,8,8],!torch.qint8>
// CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.self %[[QINT]] : !torch.vtensor<[32,3,8,8],!torch.qint8> -> !torch.vtensor<[32,3,8,8],f32>
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_channel"(%arg0, %arg1, %arg2, %axis, %min, %max, %dtype) : (!torch.vtensor<[32,3,8,8],si8>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],si8>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[32,3,8,8],f32>
return %13 : !torch.vtensor<[32,3,8,8],f32>
}