diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index fc910fa9d..a4962d12a 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1184,10 +1184,6 @@ public:
       return success();
     }
 
-    if (numSpatialDims != 2)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 2D grouped convolution supported");
-
     // Special depthwise case: Cin = Cout = groups.
     // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
     // of groups) to be depthwise in their documentation, but the linalg ops
@@ -1199,21 +1195,45 @@ public:
     if (inShape[1] == numGroups && weightShape[0] == numGroups &&
         weightShape[1] == 1) {
       // Collapse weight shape (C/G == 1)
-      SmallVector<ReassociationIndices> collapsedDims = {{0, 1}, {2}, {3}};
-      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
-                                          weightShape[2], weightShape[3]};
+      SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
+      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
+      for (unsigned i = 0; i < numSpatialDims; i++) {
+        collapsedDims.push_back({i + 2});
+        collapsedShape.push_back(weightShape[i + 2]);
+      }
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), weightDTy);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
           loc, collapsedType, weight, collapsedDims);
       if (!inputZp) {
-        conv = rewriter
-                   .create<linalg::DepthwiseConv2DNchwChwOp>(
-                       loc, outputTensor.getType(),
-                       ValueRange{paddedInput, collapsedWeight}, outputTensor,
-                       stridesAttr, dilationAttr)
-                   .getResult(0);
+        switch (numSpatialDims) {
+        case 1:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv1DNcwCwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        case 2:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv2DNchwChwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        default:
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 1D and 2D depthwise convolution "
+                  "supported for special case of group convolution");
+        };
       } else {
+        if (numSpatialDims != 2)
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 2D depthwise quantized convolution "
+                  "supported for special case of group convolution");
+
         // currently, the only named depthwise qconv op is nhwc_hwc
         // input: nchw -> nhwc; weight (collapsed): chw -> hwc
         // linalg conv result nhwc -> nchw
@@ -1260,6 +1280,10 @@ public:
       return success();
     }
 
+    if (numSpatialDims != 2)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 2D grouped convolution supported");
+
     // Grouped case, use the grouped conv linalg op
     auto expandGroups = [&](Value tensor, size_t dim) {
       auto inType = cast<RankedTensorType>(tensor.getType());
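A note on the depthwise special case above, for readers outside the pass: when `Cin == Cout == groups` and the weight's `Cin/groups` dimension is 1, the convolution is an independent 1-D correlation per channel, which is what `linalg.depthwise_conv_1d_ncw_cw` computes once `tensor.collapse_shape` folds away the unit dimension. A minimal PyTorch sketch of that equivalence, using the same shapes as the new e2e test (illustrative only, not part of the patch):

```python
import torch
import torch.nn.functional as F

# (N, C, W) input, (C, Cin/groups == 1, K) weight, groups == C == 4.
x = torch.randn(2, 4, 6)
w = torch.randn(4, 1, 3)

ref = F.conv1d(x, w, stride=1, padding=4, dilation=1, groups=4)

# Collapse the unit Cin/groups dim (the tensor.collapse_shape in the
# lowering) and correlate each channel independently.
xp = F.pad(x, (4, 4))
per_channel = [
    F.conv1d(xp[:, c : c + 1], w[c].reshape(1, 1, 3)) for c in range(4)
]
torch.testing.assert_close(ref, torch.cat(per_channel, dim=1))
```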
"Conv2dQInt8Module_grouped", @@ -4087,6 +4089,7 @@ ONNX_TOSA_XFAIL_SET = { "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 4fe50243d..3bc176048 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.float32, True), + ([4, 1, 3], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv1d( + inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4 + ) + + +@register_test_case( + module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule() +) +def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(4, 1, 3) + module.forward(inputVec, weight) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__()