mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] add support for quantized group conv (#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See [Shark-Turbine issue #566](https://github.com/nod-ai/SHARK-Turbine/issues/566). Need the op ```linalg.conv_2d_ngchw_gfchw_q``` to be added upstream before merging this. See [llvm-project PR #92136](https://github.com/llvm/llvm-project/pull/92136). A small additional expansion to operand quantization is included in this patch to address a model failure that occurs when unblocking the quantized group convolutions in one of these onnx models. (branch: pull/3436/head)
parent
6382dbbcc0
commit
8995c90879
|
@ -829,7 +829,7 @@ public:
|
|||
op, "lhs and rhs of convolution must either be both int or fp");
|
||||
}
|
||||
|
||||
if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
|
||||
if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
|
||||
auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
|
||||
if (!biasDTy.isInteger(32)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1123,7 +1123,7 @@ public:
|
|||
// - grouped 1d-3d
|
||||
// - grouped 1d-3d (quantized)
|
||||
// - ungrouped 1d-3d
|
||||
if (groupSize == 1 && !inputZp && !weightZp) {
|
||||
if (groupSize == 1 && !inputZp) {
|
||||
switch (numSpatialDims) {
|
||||
case 1:
|
||||
conv = rewriter
|
||||
|
@ -1164,7 +1164,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (groupSize == 1 && inputZp && weightZp) {
|
||||
if (groupSize == 1 && inputZp) {
|
||||
// The quantized version uses a different channel ordering so we need to
|
||||
// permute the tensors in order to use the existing path. We should
|
||||
// eventually directly support this channel ordering.
|
||||
|
@ -1224,10 +1224,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (inputZp || weightZp)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: quantized grouped convolutions");
|
||||
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
@ -1238,7 +1234,7 @@ public:
|
|||
auto weightShape = makeShapeTorchCompatible(
|
||||
cast<RankedTensorType>(weight.getType()).getShape());
|
||||
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
||||
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
|
||||
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
|
||||
// Collapse weight shape
|
||||
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
||||
SmallVector<int64_t> collapsedShape{
|
||||
|
@ -1325,13 +1321,22 @@ public:
|
|||
auto expandOutputTensor = expandGroups(outputTensor, 1);
|
||||
|
||||
// TODO: add 1D and 3D case
|
||||
conv = rewriter
|
||||
.create<linalg::Conv2DNgchwGfchwOp>(
|
||||
loc, expandOutputTensor.getResultType(),
|
||||
ValueRange{paddedInputExpanded, weightExpanded},
|
||||
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
|
||||
if (!inputZp) {
|
||||
conv = rewriter
|
||||
.create<linalg::Conv2DNgchwGfchwOp>(
|
||||
loc, expandOutputTensor.getResultType(),
|
||||
ValueRange{paddedInputExpanded, weightExpanded},
|
||||
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
} else {
|
||||
conv = rewriter
|
||||
.create<linalg::Conv2DNgchwGfchwQOp>(
|
||||
loc, expandOutputTensor.getResultType(),
|
||||
ValueRange{paddedInputExpanded, weightExpanded, inputZp,
|
||||
weightZp},
|
||||
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
}
|
||||
conv = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, outputTensor.getType(), conv,
|
||||
expandOutputTensor.getReassociationIndices());
|
||||
|
|
|
@ -378,7 +378,7 @@ public:
|
|||
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
|
||||
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
|
||||
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
|
||||
QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
|
||||
QuantizeOperandsPastCommutingOps<AtenMmOp, 4>,
|
||||
QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
|
||||
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
|
||||
context);
|
||||
|
|
|
@ -277,6 +277,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
# Dynamo not supporting conv_tbc
|
||||
"ConvTbcModule_basic",
|
||||
|
@ -373,6 +374,7 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"ConvTbcModule_basic",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
|
@ -543,6 +545,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"ConvTbcModule_basic",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
|
@ -2147,6 +2150,7 @@ LTC_XFAIL_SET = {
|
|||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
}
|
||||
|
||||
|
@ -2298,6 +2302,7 @@ ONNX_XFAIL_SET = {
|
|||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingModule_basic",
|
||||
"Conv3dModule_basic",
|
||||
|
@ -2851,6 +2856,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
||||
"Conv3dModule_basic",
|
||||
|
@ -3637,6 +3643,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
||||
|
|
|
@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Conv2dQInt8Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, groups=1):
|
||||
self.groups = groups
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
|
@ -1186,7 +1187,7 @@ class Conv2dQInt8Module(torch.nn.Module):
|
|||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
groups=1,
|
||||
groups=self.groups,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
|
|||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
N = 10
|
||||
Cin = 5
|
||||
Cout = 7
|
||||
Hin = 10
|
||||
Win = 8
|
||||
Hker = 3
|
||||
Wker = 2
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
|
||||
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
|
||||
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
|
||||
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
|
||||
bias = torch.rand(6)
|
||||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class ConvTranspose2DQInt8Module(torch.nn.Module):
|
||||
|
@ -1244,6 +1244,13 @@ class ConvTranspose2DQInt8Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
|
||||
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
|
||||
N = 10
|
||||
Cin = 5
|
||||
Cout = 7
|
||||
Hin = 10
|
||||
Win = 8
|
||||
Hker = 3
|
||||
Wker = 2
|
||||
module.forward(
|
||||
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
|
||||
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
|
||||
|
|
Loading…
Reference in New Issue