mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] add support for quantized group conv (#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See [Shark-Turbine issue #566](https://github.com/nod-ai/SHARK-Turbine/issues/566). The op ```linalg.conv_2d_ngchw_gfchw_q``` needs to be added upstream before this merges; see [llvm-project PR #92136](https://github.com/llvm/llvm-project/pull/92136). This patch also includes a small expansion of operand quantization to address a model failure that occurs when unblocking the quantized group convolutions in one of these ONNX models.
parent 6382dbbcc0
commit 8995c90879
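For reference, the behavior this unblocks can be sketched in plain PyTorch, mirroring the shapes of the new ```Conv2dQInt8Module_grouped``` test added below. This is a minimal sketch assuming the usual fake-quantization pattern (quantize, dequantize, then a float aten convolution); the scales and zero points are illustrative placeholders, not values from the patch:

```python
import torch

# Minimal sketch of a per-tensor-quantized grouped convolution. The scale /
# zero-point pairs (0.01, 7) and (0.01, 3) are made up for illustration.
def quantized_grouped_conv(inputVec, weight, bias, groups=2):
    # Reinterpret the raw int8 tensors as per-tensor quantized values, then
    # dequantize so aten.convolution sees float operands with known qparams.
    inputVec = torch.dequantize(
        torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7))
    weight = torch.dequantize(
        torch._make_per_tensor_quantized_tensor(weight, 0.01, 3))
    return torch.ops.aten.convolution(
        inputVec, weight, bias,
        stride=[1, 1], padding=[0, 0], dilation=[1, 1],
        transposed=False, output_padding=[0, 0], groups=groups)

# Shapes match the new Conv2dQInt8Module_grouped test: N=2, Cin=8, Cout=6,
# groups=2, so each of the 2 groups convolves 4 input channels.
x = torch.randint(-128, 127, (2, 8, 7, 8), dtype=torch.int8)
w = torch.randint(-128, 127, (6, 4, 3, 2), dtype=torch.int8)
b = torch.rand(6)
print(quantized_grouped_conv(x, w, b).shape)  # torch.Size([2, 6, 5, 7])
```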
```diff
@@ -829,7 +829,7 @@ public:
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
+    if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
       auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
       if (!biasDTy.isInteger(32)) {
         return rewriter.notifyMatchFailure(
@@ -1123,7 +1123,7 @@ public:
     // - grouped 1d-3d
     // - grouped 1d-3d (quantized)
     // - ungrouped 1d-3d
-    if (groupSize == 1 && !inputZp && !weightZp) {
+    if (groupSize == 1 && !inputZp) {
       switch (numSpatialDims) {
       case 1:
         conv = rewriter
@@ -1164,7 +1164,7 @@ public:
       return success();
     }
 
-    if (groupSize == 1 && inputZp && weightZp) {
+    if (groupSize == 1 && inputZp) {
       // The quantized version uses a different channel ordering so we need to
       // permute the tensors in order to use the existing path. We should
       // eventually directly support this channel ordering.
@@ -1224,10 +1224,6 @@ public:
       return success();
     }
 
-    if (inputZp || weightZp)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: quantized grouped convolutions");
-
    if (numSpatialDims != 2)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only 2D grouped convolution supported");
@@ -1238,7 +1234,7 @@ public:
     auto weightShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(weight.getType()).getShape());
     if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
-        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
+        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
       // Collapse weight shape
       SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
       SmallVector<int64_t> collapsedShape{
@@ -1325,13 +1321,22 @@ public:
     auto expandOutputTensor = expandGroups(outputTensor, 1);
 
     // TODO: add 1D and 3D case
-    conv = rewriter
-               .create<linalg::Conv2DNgchwGfchwOp>(
-                   loc, expandOutputTensor.getResultType(),
-                   ValueRange{paddedInputExpanded, weightExpanded},
-                   expandOutputTensor.getResult(), stridesAttr, dilationAttr)
-               .getResult(0);
+    if (!inputZp) {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    } else {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwQOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
+                                weightZp},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    }
     conv = rewriter.create<tensor::CollapseShapeOp>(
         loc, outputTensor.getType(), conv,
         expandOutputTensor.getReassociationIndices());
```
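The grouped path above expands the channel dimensions into explicit (group, channels-per-group) form so the existing NGCHW/GFCHW linalg ops (and now their quantized ```_q``` variants) can be reused, then collapses the result back. A minimal PyTorch sketch of that expand-convolve-collapse scheme follows; it models the shape bookkeeping only, not the lowering itself, and the shapes are illustrative:

```python
import torch

# expandGroups corresponds to the .view() calls below; the final reshape
# plays the role of tensor.collapse_shape on the conv result.
N, G, Cpg, H, W = 2, 2, 4, 7, 8  # input is (N, G*Cpg, H, W)
F, Kh, Kw = 3, 3, 2              # weight is (G*F, Cpg, Kh, Kw)

x = torch.randn(N, G * Cpg, H, W)
w = torch.randn(G * F, Cpg, Kh, Kw)

xg = x.view(N, G, Cpg, H, W)    # NGCHW: groups made explicit
wg = w.view(G, F, Cpg, Kh, Kw)  # GFCHW: one filter bank per group

# Each group is an ordinary 2D conv; stack the per-group outputs on a group
# dim, then collapse (G, F) back into a single channel dimension.
out = torch.stack(
    [torch.nn.functional.conv2d(xg[:, g], wg[g]) for g in range(G)], dim=1
).reshape(N, G * F, H - Kh + 1, W - Kw + 1)

ref = torch.nn.functional.conv2d(x, w, groups=G)  # same computation, direct
assert torch.allclose(out, ref, atol=1e-5)
```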
```diff
@@ -378,7 +378,7 @@ public:
         QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
         QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
         QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
-        QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 4>,
         QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
         context);
```
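The only fusion change bumps the commuting-op depth for ```AtenMmOp``` from 2 to 4. Assuming the depth parameter bounds how many quantization-commuting ops (views, transposes, and the like) the pattern may traverse between a dequantize and its consumer — an assumption based on the pattern's name, not verified here — the kind of graph that needs the larger bound looks roughly like:

```python
import torch

# Hypothetical illustration: the dequantize feeding the mm is separated from
# it by several ops that commute with quantization, so the rewrite must
# search further up the use-def chain. Scale/zero point are made up.
q = torch._make_per_tensor_quantized_tensor(
    torch.randint(-128, 127, (4, 6), dtype=torch.int8), 0.05, 0)
lhs = torch.dequantize(q)
lhs = lhs.reshape(2, 2, 6).transpose(0, 1).reshape(4, 6)  # commuting ops
out = torch.mm(lhs, torch.randn(6, 3))
```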
```diff
@@ -277,6 +277,7 @@ TORCHDYNAMO_XFAIL_SET = {
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
@@ -373,6 +374,7 @@ FX_IMPORTER_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
@@ -543,6 +545,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -2147,6 +2150,7 @@ LTC_XFAIL_SET = {
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
 }
 
@@ -2298,6 +2302,7 @@ ONNX_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
@@ -2851,6 +2856,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ContainsIntList_True",
     "Conv1dModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Conv3dModule_basic",
@@ -3637,6 +3643,7 @@ ONNX_TOSA_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
```
```diff
@@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):
 
 
 class Conv2dQInt8Module(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, groups=1):
+        self.groups = groups
         super().__init__()
 
     @export
@@ -1186,7 +1187,7 @@ class Conv2dQInt8Module(torch.nn.Module):
             stride=[1, 1],
             padding=[0, 0],
             dilation=[1, 1],
-            groups=1,
+            groups=self.groups,
         )
 
 
@@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
-N = 10
-Cin = 5
-Cout = 7
-Hin = 10
-Win = 8
-Hker = 3
-Wker = 2
+@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
+def Conv2dQInt8Module_grouped(module, tu: TestUtils):
+    inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(6)
+    module.forward(inputVec, weight, bias)
 
 
 class ConvTranspose2DQInt8Module(torch.nn.Module):
@@ -1244,6 +1244,13 @@ class ConvTranspose2DQInt8Module(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
 def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
+    N = 10
+    Cin = 5
+    Cout = 7
+    Hin = 10
+    Win = 8
+    Hker = 3
+    Wker = 2
     module.forward(
         tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
         tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
```