[TorchToLinalg] add support for depthwise qconv (#3564)

- Adds support for lowering depthwise + quantized convolution ops to
  linalg::DepthwiseConv2DNhwcHwcQOp.
- Renames the variable groupSize to numGroups: it holds the number of groups
  (G), while the old name suggested the per-group channel count (C/G).
- E2e testing showed that linalg does not accept (Cin = groups &&
  Cout = K*groups for K > 1) as a "depthwise" conv, even though PyTorch's
  documentation does, so the depthwise case check is tightened to require
  Cin = Cout = groups.
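For reference, the distinction in the last bullet can be reproduced in eager PyTorch. The snippet below is a standalone sketch (not part of this patch); the shapes mirror the new e2e tests. Both calls are legal in PyTorch, but only the first satisfies the shape constraints of linalg's named depthwise ops (weight dim 0 == Cin == groups).

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 12, 12)  # N, Cin, H, W with Cin = 3

# Strict depthwise: Cout == Cin == groups. This is the case the new
# lowering maps to a named linalg depthwise op.
w_dw = torch.randn(3, 1, 5, 3)          # (Cout=3, Cin/G=1, kH, kW)
y_dw = F.conv2d(x, w_dw, groups=3)      # -> (2, 3, 8, 10)

# PyTorch also calls this "depthwise" (channel multiplier K = 2), but
# linalg's named depthwise ops reject it, so it stays on the generic
# grouped-convolution path.
w_mult = torch.randn(6, 1, 5, 3)        # (Cout=6 = 2*groups, Cin/G=1, kH, kW)
y_mult = F.conv2d(x, w_mult, groups=3)  # -> (2, 6, 8, 10)
```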
zjgarvey 2024-07-29 12:25:07 -07:00 committed by GitHub
parent 50d6ce225f
commit f1c74e1431
3 changed files with 138 additions and 45 deletions


@@ -788,7 +788,7 @@ public:
     Location loc = op->getLoc();
     MLIRContext *context = op->getContext();
     Value input = adaptor.getInput();   /* in form of N*C*H*W */
-    Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
+    Value weight = adaptor.getWeight(); /* in form of F*C/G*H*W */
     Value bias = adaptor.getBias();
     auto resultTy = cast<ValueTensorType>(op.getType());
@@ -898,8 +898,8 @@ public:
       weightDims.push_back(getDimOp(rewriter, loc, weight, i));

     // Checks for valid group size
-    int64_t groupSize;
-    if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groupSize)))
+    int64_t numGroups;
+    if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
       return rewriter.notifyMatchFailure(op,
                                          "only constant group size supported.");
     Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
@@ -1118,14 +1118,14 @@ public:
     Value conv;

     // the code so far is able to respect all numSpatialDims
-    // the code below this point is numSpatialDims specific and groupSize
+    // the code below this point is numSpatialDims specific and numGroups
     // specific
     // TODO: factor out the above code into a helper function, and then separate
     // convolution into:
     // - grouped 1d-3d
     // - grouped 1d-3d (quantized)
     // - ungrouped 1d-3d
-    if (groupSize == 1 && !inputZp) {
+    if (numGroups == 1 && !inputZp) {
       switch (numSpatialDims) {
       case 1:
         conv = rewriter
@@ -1166,7 +1166,7 @@ public:
       return success();
     }

-    if (groupSize == 1 && inputZp) {
+    if (numGroups == 1 && inputZp) {
       // The quantized version uses a different channel ordering so we need to
       // permute the tensors in order to use the existing path. We should
       // eventually directly support this channel ordering.
@@ -1230,30 +1230,66 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D grouped convolution supported");

-    // Special depthwise case
+    // Special depthwise case: Cin = Cout = groups.
+    // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
+    // of groups) to be depthwise in their documentation, but the linalg ops
+    // apparently disagree.
     auto inShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(input.getType()).getShape());
     auto weightShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(weight.getType()).getShape());
-    if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
-        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
-      // Collapse weight shape
+    if (inShape[1] == numGroups && weightShape[0] == numGroups &&
+        weightShape[1] == 1) {
+      // Collapse weight shape (C/G == 1)
       SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
-      SmallVector<int64_t> collapsedShape{
-          (weightShape[0] == kUnknownSize ? kUnknownSize
-                                          : weightShape[0] * weightShape[1]),
-          weightShape[2], weightShape[3]};
+      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
+                                          weightShape[2], weightShape[3]};
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), weightDTy);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
           loc, collapsedType, weight, collapsedDims);
+      if (!inputZp) {
        conv = rewriter
                   .create<linalg::DepthwiseConv2DNchwChwOp>(
                       loc, outputTensor.getType(),
                       ValueRange{paddedInput, collapsedWeight}, outputTensor,
                       stridesAttr, dilationAttr)
                   .getResult(0);
+      } else {
+        // currently, the only named depthwise qconv op is nhwc_hwc
+        // input: nchw -> nhwc; weight (collapsed): chw -> hwc
+        // linalg conv result nhwc -> nchw
+        // inPerms = [0, 2, 3, 1]
+        // weightPerms = [1, 2, 0]
+        // resultPerms = [0, 3, 1, 2]
+        llvm::SmallVector<int64_t> inPerms, weightPerms, resultPerms;
+        inPerms.push_back(0);
+        resultPerms.append({0, static_cast<int64_t>(numSpatialDims + 1)});
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          inPerms.push_back(i + 2);
+          weightPerms.push_back(i + 1);
+          resultPerms.push_back(i + 1);
+        }
+        inPerms.push_back(1);
+        weightPerms.push_back(0);
+
+        paddedInput =
+            transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
+        collapsedWeight =
+            transposeValue(op.getLoc(), collapsedWeight, weightPerms, rewriter);
+        outputTensor =
+            transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
+
+        conv =
+            rewriter
+                .create<linalg::DepthwiseConv2DNhwcHwcQOp>(
+                    loc, outputTensor.getType(),
+                    ValueRange{paddedInput, collapsedWeight, inputZp, weightZp},
+                    outputTensor, stridesAttr, dilationAttr)
+                .getResult(0);
+        // convert output nhwc -> nchw
+        conv = transposeValue(op.getLoc(), conv, resultPerms, rewriter);
+      }

       Type newResultType = getTypeConverter()->convertType(op.getType());
       if (accumulatorDType != resultDTy) {
@@ -1274,12 +1310,12 @@ public:
       SmallVector<int64_t> outShape;
       for (auto i = 0; i < (long)inShape.size(); i++) {
         if (i == 1) {
-          outShape.push_back(groupSize);
+          outShape.push_back(numGroups);
         }
         if (i == (long)dim) {
           outShape.push_back(inShape[i] == kUnknownSize
                                  ? kUnknownSize
-                                 : inShape[i] / groupSize);
+                                 : inShape[i] / numGroups);
         } else {
           outShape.push_back(inShape[i]);
         }
@@ -1305,8 +1341,8 @@ public:
       auto inShape = makeShapeTorchCompatible(inType.getShape());
       SmallVector<int64_t> outShape{
-          groupSize,
-          (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)};
+          numGroups,
+          (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
       outShape.append(inShape.begin() + 1, inShape.end());
       SmallVector<ReassociationIndices> indices{{0, 1}};
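The quantized branch above needs three permutation vectors because the only named depthwise quantized convolution in linalg is the NHWC/HWC variant. Here is a small sanity check of that bookkeeping for the 2-D case (an illustrative Python sketch, not part of the patch):

```python
# Verify the 2-D permutation vectors built in the else-branch above:
# inPerms moves NCHW -> NHWC, weightPerms moves the collapsed (C, kH, kW)
# weight to (kH, kW, C), and resultPerms is the inverse of inPerms, so it
# maps the NHWC result back to NCHW.
def apply_perm(shape, perm):
    return [shape[p] for p in perm]

n, c, h, w, kh, kw = 2, 3, 12, 10, 5, 3
in_perms, weight_perms, result_perms = [0, 2, 3, 1], [1, 2, 0], [0, 3, 1, 2]

assert apply_perm([n, c, h, w], in_perms) == [n, h, w, c]
assert apply_perm([c, kh, kw], weight_perms) == [kh, kw, c]
assert apply_perm(apply_perm([n, c, h, w], in_perms), result_perms) == [n, c, h, w]
```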


@@ -16,9 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
 print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())

 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
-    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
-    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "IscloseStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
@@ -250,9 +247,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "ScatterValueIntModule_basic",
     # AssertionError: Unregistered operation: torch.aten._unsafe_index_put
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
-    # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
-    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
-    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
     "AtenEmbeddingBagStaticModule_basic",
     # Lowering not present for this case
@@ -281,7 +275,9 @@ TORCHDYNAMO_XFAIL_SET = {
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
@@ -380,8 +376,9 @@ FX_IMPORTER_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
-    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
+    "Conv2dQInt8Module_not_depthwise",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -547,7 +544,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -2204,7 +2203,9 @@ LTC_XFAIL_SET = {
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "ConvTranspose2DQInt8_basic",
 }
@@ -2350,7 +2351,9 @@ ONNX_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
@@ -2718,7 +2721,6 @@ ONNX_XFAIL_SET = {
     "BernoulliModule_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose3dModule_basic",
-    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
@@ -2922,7 +2924,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ContainsIntList_True",
     "Conv1dModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Conv3dModule_basic",
@@ -3715,7 +3719,9 @@ ONNX_TOSA_XFAIL_SET = {
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
+    "Conv2dQInt8Module_not_depthwise",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",


@@ -1156,21 +1156,12 @@ def ConvTbcModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))


-class Conv2dQInt8Module(torch.nn.Module):
+class Conv2dQInt8ModuleBase(torch.nn.Module):
     def __init__(self, groups=1):
         self.groups = groups
         super().__init__()

-    @export
-    @annotate_args(
-        [
-            None,
-            ([-1, -1, -1, -1], torch.int8, True),
-            ([-1, -1, -1, -1], torch.int8, True),
-            ([-1], torch.float, True),
-        ]
-    )
-    def forward(self, inputVec, weight, bias):
+    def _forward(self, inputVec, weight, bias):
         inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
         inputVec = torch.dequantize(inputVec)
@@ -1191,7 +1182,49 @@ class Conv2dQInt8Module(torch.nn.Module):
         )


-@register_test_case(module_factory=lambda: Conv2dQInt8Module())
+class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.int8, True),
+            ([-1, -1, -1, -1], torch.int8, True),
+            ([-1], torch.float, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return self._forward(inputVec, weight, bias)
+
+
+class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 12, 12], torch.int8, True),
+            ([3, 1, 5, 3], torch.int8, True),
+            ([3], torch.float, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return self._forward(inputVec, weight, bias)
+
+
+class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 12, 12], torch.int8, True),
+            ([6, 1, 5, 3], torch.int8, True),
+            ([6], torch.float, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return self._forward(inputVec, weight, bias)
+
+
+@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn())
 def Conv2dQInt8Module_basic(module, tu: TestUtils):
     inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
@@ -1199,7 +1232,7 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)


-@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
+@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
 def Conv2dQInt8Module_grouped(module, tu: TestUtils):
     inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
@@ -1207,6 +1240,24 @@ def Conv2dQInt8Module_grouped(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)


+@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
+def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
+    inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(3)
+    module.forward(inputVec, weight, bias)
+
+
+@register_test_case(
+    module_factory=lambda: Conv2dQInt8ModuleStatic_MoreOutChannels(groups=3)
+)
+def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
+    inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(6)
+    module.forward(inputVec, weight, bias)
+
+
 class ConvTranspose2DQInt8Module(torch.nn.Module):
     def __init__(self):
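As a closing note, the new Conv2dQInt8Module_depthwise test above reduces to roughly the following eager-mode computation (an illustrative sketch, not test code; the weight quantization parameters are assumed, since that part of _forward lies outside the visible hunk):

```python
import torch
import torch.nn.functional as F

inputVec = torch.randint(-128, 127, (2, 3, 12, 12), dtype=torch.int8)
weight = torch.randint(-128, 127, (3, 1, 5, 3), dtype=torch.int8)
bias = torch.rand(3)

# Fake-quantize/dequantize the int8 input exactly as _forward does ...
x = torch.dequantize(torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7))
# ... and the weight with an assumed scale/zero point for illustration.
w = torch.dequantize(torch._make_per_tensor_quantized_tensor(weight, 0.01, 0))

# groups == Cin == Cout == 3: the strict depthwise case the new lowering
# routes to linalg::DepthwiseConv2DNhwcHwcQOp.
out = F.conv2d(x, w, bias, groups=3)
print(out.shape)  # torch.Size([2, 3, 8, 10])
```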