mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
This commit adds support for 1-d depthwise convolution as a special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
(branch: pull/3776/head)
parent f6721e5999
commit 614fcdd153
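For readers unfamiliar with the term: a 1-d depthwise convolution is a grouped convolution in which groups == in_channels (and here out_channels == groups as well), so each input channel is convolved with its own single-channel kernel and the weight has shape (out_channels, 1, kernel_size). A minimal PyTorch illustration, not part of the commit; the shapes simply mirror the new e2e test below:

import torch

# Depthwise 1-d convolution: groups == in_channels, weight shape (C_out, 1, K).
depthwise = torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, groups=4, bias=False)
print(depthwise.weight.shape)  # torch.Size([4, 1, 3])
y = depthwise(torch.randn(2, 4, 6))  # (N, C, L) -> (N, C, L - K + 1) = (2, 4, 4)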
@@ -1184,10 +1184,6 @@ public:
       return success();
     }
 
-    if (numSpatialDims != 2)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 2D grouped convolution supported");
-
     // Special depthwise case: Cin = Cout = groups.
     // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
     // of groups) to be depthwise in their documentation, but the linalg ops
@@ -1199,21 +1195,45 @@ public:
     if (inShape[1] == numGroups && weightShape[0] == numGroups &&
         weightShape[1] == 1) {
       // Collapse weight shape (C/G == 1)
-      SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
-      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
-                                          weightShape[2], weightShape[3]};
+      SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
+      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
+      for (unsigned i = 0; i < numSpatialDims; i++) {
+        collapsedDims.push_back({i + 2});
+        collapsedShape.push_back(weightShape[i + 2]);
+      }
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), weightDTy);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
           loc, collapsedType, weight, collapsedDims);
       if (!inputZp) {
-        conv = rewriter
-                   .create<linalg::DepthwiseConv2DNchwChwOp>(
-                       loc, outputTensor.getType(),
-                       ValueRange{paddedInput, collapsedWeight}, outputTensor,
-                       stridesAttr, dilationAttr)
-                   .getResult(0);
+        switch (numSpatialDims) {
+        case 1:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv1DNcwCwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        case 2:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv2DNchwChwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        default:
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 1D and 2D depthwise convolution "
+                  "supported for special case of group convolution");
+        };
       } else {
+        if (numSpatialDims != 2)
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 2D depthwise quantized convolution "
+                  "supported for special case of group convolution");
+
         // currently, the only named depthwise qconv op is nhwc_hwc
         // input: nchw -> nhwc; weight (collapsed): chw -> hwc
         // linalg conv result nhwc -> nchw
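The depthwise special case above collapses the weight's leading (C_out, C_in/groups == 1) dimensions into one, so a (C, 1, K) weight becomes (C, K), which is the layout the linalg.depthwise_conv_1d_ncw_cw named op consumes (and analogously (C, 1, H, W) -> (C, H, W) for the 2-d op). A small PyTorch sketch of the semantics, not of the lowering code itself; the names N, C, L, K are illustrative:

import torch
import torch.nn.functional as F

# Each channel c is convolved with its own kernel w[c]; the (C, 1, K) weight
# is first collapsed to (C, K), mirroring the CollapseShapeOp in the lowering.
N, C, L, K = 2, 4, 6, 3
x = torch.randn(N, C, L)
w = torch.randn(C, 1, K)

reference = F.conv1d(x, w, groups=C)  # grouped conv with groups == C_in

w_collapsed = w.reshape(C, K)  # collapse (C, 1, K) -> (C, K)
per_channel = [
    F.conv1d(x[:, c : c + 1, :], w_collapsed[c].view(1, 1, K)).squeeze(1)
    for c in range(C)
]
manual = torch.stack(per_channel, dim=1)
assert torch.allclose(reference, manual, atol=1e-6)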
@@ -1260,6 +1280,10 @@ public:
       return success();
     }
 
+    if (numSpatialDims != 2)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 2D grouped convolution supported");
+
     // Grouped case, use the grouped conv linalg op
     auto expandGroups = [&](Value tensor, size_t dim) {
       auto inType = cast<RankedTensorType>(tensor.getType());
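As far as this diff shows, the guard re-added here only affects the general grouped path: a 1-d convolution that is grouped but not depthwise (channels per group > 1) does not satisfy the depthwise special case above and therefore still fails to lower after this patch. A hypothetical example of such a case (the op itself is valid in PyTorch; only this torch-to-linalg lowering rejects it):

import torch

# Grouped but not depthwise: C_in / groups == 2, so the depthwise fast path
# above does not apply and the general grouped path (still 2-d only) is hit.
x = torch.randn(2, 4, 6)  # (N, C_in, L)
w = torch.randn(4, 2, 3)  # (C_out, C_in / groups, K) with groups == 2
y = torch.ops.aten.conv1d(x, w, bias=None, stride=[1], padding=[0], dilation=[1], groups=2)
print(y.shape)  # torch.Size([2, 4, 4])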
@@ -1048,6 +1048,7 @@ STABLEHLO_PASS_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "ContiguousModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@@ -3395,6 +3396,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -4087,6 +4089,7 @@ ONNX_TOSA_XFAIL_SET = {
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
@@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
+class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.float32, True),
+            ([4, 1, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv1d(
+            inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
+        )
+
+
+@register_test_case(
+    module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
+)
+def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 4, 6)
+    weight = torch.randn(4, 1, 3)
+    module.forward(inputVec, weight)
+
+
 class Conv2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
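For reference, the new test exercises padding larger than the kernel: with input length 6, kernel 3, stride 1, padding 4, and dilation 1, the output length is 6 + 2*4 - 1*(3 - 1) - 1 + 1 = 12. A standalone sanity check of the same shapes (illustrative, not part of the commit):

import torch

x = torch.rand(2, 4, 6)
w = torch.randn(4, 1, 3)
y = torch.ops.aten.conv1d(x, w, bias=None, stride=[1], padding=[4], dilation=[1], groups=4)
assert y.shape == (2, 4, 12)  # L_out = 6 + 2*4 - 1*(3 - 1) - 1 + 1 = 12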