mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
This commit adds the support for the 1-d depthwise convolution as a special case of 1-d group convolution. Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3776/head
parent
f6721e5999
commit
614fcdd153
|
@ -1184,10 +1184,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
||||
// Special depthwise case: Cin = Cout = groups.
|
||||
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
||||
// of groups) to be depthwise in their documentation, but the linalg ops
|
||||
|
@ -1199,21 +1195,45 @@ public:
|
|||
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
||||
weightShape[1] == 1) {
|
||||
// Collapse weight shape (C/G == 1)
|
||||
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
|
||||
weightShape[2], weightShape[3]};
|
||||
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
|
||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
|
||||
for (unsigned i = 0; i < numSpatialDims; i++) {
|
||||
collapsedDims.push_back({i + 2});
|
||||
collapsedShape.push_back(weightShape[i + 2]);
|
||||
}
|
||||
Type collapsedType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
||||
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, collapsedType, weight, collapsedDims);
|
||||
if (!inputZp) {
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
switch (numSpatialDims) {
|
||||
case 1:
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv1DNcwCwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
break;
|
||||
case 2:
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
break;
|
||||
default:
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 1D and 2D depthwise convolution "
|
||||
"supported for special case of group convolution");
|
||||
};
|
||||
} else {
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D depthwise quantized convolution "
|
||||
"supported for special case of group convolution");
|
||||
|
||||
// currently, the only named depthwise qconv op is nhwc_hwc
|
||||
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
||||
// linalg conv result nhwc -> nchw
|
||||
|
@ -1260,6 +1280,10 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
||||
// Grouped case, use the grouped conv linalg op
|
||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||
|
|
|
@ -1048,6 +1048,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"ContiguousModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
@ -3395,6 +3396,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
|
@ -4087,6 +4089,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dBiasNoPaddingModule_basic",
|
||||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
|
|
|
@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
|
|||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 4, 6], torch.float32, True),
|
||||
([4, 1, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv1d(
|
||||
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
|
||||
)
|
||||
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
||||
inputVec = tu.rand(2, 4, 6)
|
||||
weight = torch.randn(4, 1, 3)
|
||||
module.forward(inputVec, weight)
|
||||
|
||||
|
||||
class Conv2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue