[TorchToLinalg] add support for quantized group conv (#3341)

This addresses 7 of the model failures I'm seeing in the test suite. See
[Shark-Turbine issue
#566](https://github.com/nod-ai/SHARK-Turbine/issues/566).

The op `linalg.conv_2d_ngchw_gfchw_q` needs to land upstream before this
merges. See [llvm-project PR #92136
](https://github.com/llvm/llvm-project/pull/92136).

This patch also includes a small expansion of operand quantization to
address a model failure that surfaces once the quantized group
convolutions in one of these ONNX models are unblocked.
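
For context, a minimal sketch of what a quantized grouped convolution computes (illustrative only, not the lowering itself): each int8 operand is shifted by its zero point, and the products are accumulated per group in a wide accumulator, which is the contraction `linalg.conv_2d_ngchw_gfchw_q` expresses. The helper name is made up; the shapes mirror the new e2e test.

```python
import torch

# Illustrative reference for quantized grouped-conv semantics: accumulate
# (input - input_zp) * (weight - weight_zp) per group. float64 accumulation
# is exact at int8 magnitudes and stands in for the real i32 accumulator.
def quantized_grouped_conv2d(x_i8, w_i8, x_zp, w_zp, groups):
    x = x_i8.to(torch.float64) - x_zp
    w = w_i8.to(torch.float64) - w_zp
    return torch.nn.functional.conv2d(x, w, groups=groups).to(torch.int32)

x = torch.randint(-128, 128, (2, 8, 7, 8)).to(torch.int8)  # N, G*C, H, W
w = torch.randint(-128, 128, (6, 4, 3, 2)).to(torch.int8)  # G*F, C, KH, KW
print(quantized_grouped_conv2d(x, w, x_zp=3, w_zp=-1, groups=2).shape)
# torch.Size([2, 6, 5, 7])
```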
zjgarvey 2024-06-03 11:27:44 -05:00 committed by GitHub
parent 6382dbbcc0
commit 8995c90879
4 changed files with 44 additions and 25 deletions


```diff
@@ -829,7 +829,7 @@ public:
           op, "lhs and rhs of convolution must either be both int or fp");
     }
-    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
+    if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
       auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
       if (!biasDTy.isInteger(32)) {
         return rewriter.notifyMatchFailure(
@@ -1123,7 +1123,7 @@ public:
     // - grouped 1d-3d
     // - grouped 1d-3d (quantized)
     // - ungrouped 1d-3d
-    if (groupSize == 1 && !inputZp && !weightZp) {
+    if (groupSize == 1 && !inputZp) {
       switch (numSpatialDims) {
       case 1:
         conv = rewriter
@@ -1164,7 +1164,7 @@ public:
       return success();
     }
-    if (groupSize == 1 && inputZp && weightZp) {
+    if (groupSize == 1 && inputZp) {
       // The quantized version uses a different channel ordering so we need to
       // permute the tensors in order to use the existing path. We should
       // eventually directly support this channel ordering.
```
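
The channel-ordering comment in the hunk above refers to reusing an existing quantized kernel that expects a different layout. A rough sketch of the permute-compute-permute-back idea in PyTorch terms (the NHWC/HWCF layouts and names here are assumptions for illustration, not the exact kernels involved):

```python
import torch

# Sketch of permuting into the layout an existing kernel expects, running it,
# and permuting back (assumed NHWC/HWCF layouts, illustrative names).
def conv_via_permute(x_nchw, w_fchw, conv_nhwc_hwcf):
    x_nhwc = x_nchw.permute(0, 2, 3, 1)    # NCHW -> NHWC
    w_hwcf = w_fchw.permute(2, 3, 1, 0)    # FCHW -> HWCF
    out_nhwc = conv_nhwc_hwcf(x_nhwc, w_hwcf)
    return out_nhwc.permute(0, 3, 1, 2)    # NHWC -> NCHW
```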
```diff
@@ -1224,10 +1224,6 @@ public:
       return success();
     }
-    if (inputZp || weightZp)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: quantized grouped convolutions");
-
     if (numSpatialDims != 2)
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D grouped convolution supported");
@@ -1238,7 +1234,7 @@ public:
     auto weightShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(weight.getType()).getShape());
     if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
-        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
+        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
       // Collapse weight shape
       SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
       SmallVector<int64_t> collapsedShape{
@@ -1325,13 +1321,22 @@ public:
     auto expandOutputTensor = expandGroups(outputTensor, 1);
     // TODO: add 1D and 3D case
+    if (!inputZp) {
       conv = rewriter
                  .create<linalg::Conv2DNgchwGfchwOp>(
                      loc, expandOutputTensor.getResultType(),
                      ValueRange{paddedInputExpanded, weightExpanded},
                      expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                  .getResult(0);
+    } else {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwQOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
+                                weightZp},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    }
     conv = rewriter.create<tensor::CollapseShapeOp>(
         loc, outputTensor.getType(), conv,
         expandOutputTensor.getReassociationIndices());
```
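
The structure of this grouped path, sketched in PyTorch terms (assumed helper, not the lowering): expand the channel dimension into an explicit group dimension, convolve group-wise (a single `linalg.conv_2d_ngchw_gfchw(_q)` op in the lowering), then collapse the group dimension back into channels, mirroring the `tensor.collapse_shape` above.

```python
import torch

# Expand -> group-wise conv -> collapse, the same shape dance as the lowering.
def grouped_conv_by_expansion(x, w, groups):
    n, c, h, width = x.shape
    f = w.shape[0]
    xg = x.reshape(n, groups, c // groups, h, width)       # N, G, C, H, W
    wg = w.reshape(groups, f // groups, *w.shape[1:])      # G, F, C, KH, KW
    outs = [torch.nn.functional.conv2d(xg[:, g], wg[g])    # one conv per group
            for g in range(groups)]
    out = torch.stack(outs, dim=1)                         # N, G, F, H', W'
    return out.reshape(n, f, *out.shape[3:])               # collapse G into C
```

For float tensors this agrees with `torch.nn.functional.conv2d(x, w, groups=groups)`.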


```diff
@@ -378,7 +378,7 @@ public:
         QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
         QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
         QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
-        QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 4>,
         QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
         context);
```
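
On the `2` to `4` bump above: as I read `QuantizeOperandsPastCommutingOps`, the integer bounds how many quantization-commuting ops (view-like ops such as reshape, transpose, or slice) the rewrite may walk through between a dequantized operand and the consuming `aten.mm`. A hypothetical graph shape that needs the larger depth (plain PyTorch, not torch-mlir API):

```python
import torch

# The dequantize sits three view-like ops away from the mm; a depth of 2
# would stop short of it, while a depth of 4 reaches it. (Hypothetical.)
def model(q):
    x = torch.dequantize(q)
    x = x.reshape(4, 16)      # commuting op 1
    x = x.transpose(0, 1)     # commuting op 2
    x = x[:8, :]              # commuting op 3
    return torch.mm(x, torch.ones(4, 4))

q = torch.quantize_per_tensor(torch.randn(64), scale=0.1, zero_point=0,
                              dtype=torch.qint8)
print(model(q).shape)  # torch.Size([8, 4])
```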


```diff
@@ -277,6 +277,7 @@ TORCHDYNAMO_XFAIL_SET = {
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
@@ -373,6 +374,7 @@ FX_IMPORTER_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
@@ -543,6 +545,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -2147,6 +2150,7 @@ LTC_XFAIL_SET = {
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
 }
@@ -2298,6 +2302,7 @@ ONNX_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
@@ -2851,6 +2856,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ContainsIntList_True",
     "Conv1dModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Conv3dModule_basic",
@@ -3637,6 +3643,7 @@ ONNX_TOSA_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
```


```diff
@@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):
 class Conv2dQInt8Module(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, groups=1):
+        self.groups = groups
         super().__init__()
 
     @export
@@ -1186,7 +1187,7 @@ class Conv2dQInt8Module(torch.nn.Module):
             stride=[1, 1],
             padding=[0, 0],
             dilation=[1, 1],
-            groups=1,
+            groups=self.groups,
         )
@@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
-N = 10
-Cin = 5
-Cout = 7
-Hin = 10
-Win = 8
-Hker = 3
-Wker = 2
+@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
+def Conv2dQInt8Module_grouped(module, tu: TestUtils):
+    inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(6)
+    module.forward(inputVec, weight, bias)
 
 
 class ConvTranspose2DQInt8Module(torch.nn.Module):
@@ -1244,6 +1244,13 @@ class ConvTranspose2DQInt8Module(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
 def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
+    N = 10
+    Cin = 5
+    Cout = 7
+    Hin = 10
+    Win = 8
+    Hker = 3
+    Wker = 2
     module.forward(
         tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
         tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
```