[MLIR][TORCH] Add support for 1-d group convolution (#3770)

This commit adds the support for the 1-d depthwise convolution as a
special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3776/head
Vivek Khandelwal 2024-10-08 10:48:47 +05:30 committed by GitHub
parent f6721e5999
commit 614fcdd153
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 13 deletions

View File

@ -1184,10 +1184,6 @@ public:
return success(); return success();
} }
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
// Special depthwise case: Cin = Cout = groups. // Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops // of groups) to be depthwise in their documentation, but the linalg ops
@ -1199,21 +1195,45 @@ public:
if (inShape[1] == numGroups && weightShape[0] == numGroups && if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) { weightShape[1] == 1) {
// Collapse weight shape (C/G == 1) // Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}}; SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1], SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
weightShape[2], weightShape[3]}; for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get( Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy); makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>( Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims); loc, collapsedType, weight, collapsedDims);
if (!inputZp) { if (!inputZp) {
conv = rewriter switch (numSpatialDims) {
.create<linalg::DepthwiseConv2DNchwChwOp>( case 1:
loc, outputTensor.getType(), conv = rewriter
ValueRange{paddedInput, collapsedWeight}, outputTensor, .create<linalg::DepthwiseConv1DNcwCwOp>(
stridesAttr, dilationAttr) loc, outputTensor.getType(),
.getResult(0); ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else { } else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");
// currently, the only named depthwise qconv op is nhwc_hwc // currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc // input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw // linalg conv result nhwc -> nchw
@ -1260,6 +1280,10 @@ public:
return success(); return success();
} }
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
// Grouped case, use the grouped conv linalg op // Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) { auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType()); auto inType = cast<RankedTensorType>(tensor.getType());

View File

@ -1048,6 +1048,7 @@ STABLEHLO_PASS_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"ContiguousModule_basic", "ContiguousModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@ -3395,6 +3396,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"Conv1dModule_basic", "Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped", "Conv2dQInt8Module_grouped",
@ -4087,6 +4089,7 @@ ONNX_TOSA_XFAIL_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"Conv1dModule_basic", "Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dBiasNoPaddingModule_basic", "Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic", "Conv2dModule_basic",
"Conv2dNoPaddingModule_basic", "Conv2dNoPaddingModule_basic",

View File

@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias) module.forward(inputVec, weight, bias)
class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 4, 6], torch.float32, True),
([4, 1, 3], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.conv1d(
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)
@register_test_case(
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
)
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
inputVec = tu.rand(2, 4, 6)
weight = torch.randn(4, 1, 3)
module.forward(inputVec, weight)
class Conv2dModule(torch.nn.Module): class Conv2dModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()