diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index a4962d12a..9c914690b 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1125,54 +1125,57 @@ public:
     }
 
     if (numGroups == 1 && inputZp) {
-      // The quantized version uses a different channel ordering so we need to
-      // permute the tensors in order to use the existing path. We should
-      // eventually directly support this channel ordering.
-      llvm::SmallVector<int64_t> inPerms, weightPerms;
-      inPerms.push_back(0); // N stays at the front for input.
-      // Then we expect the spatial dimensions
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        inPerms.push_back(i + 2);
-        weightPerms.push_back(i + 2);
-      }
-      inPerms.push_back(1);
-      weightPerms.append({1, 0});
-
-      paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
-      weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
-      outputTensor =
-          transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
-
       switch (numSpatialDims) {
       case 2:
         conv = rewriter
-                   .create<linalg::Conv2DNhwcHwcfQOp>(
+                   .create<linalg::Conv2DNchwFchwQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                    .getResult(0);
         break;
-      case 3:
+      case 3: {
+        // The quantized version uses a different channel ordering so we need to
+        // permute the tensors in order to use the existing path. We should
+        // eventually directly support this channel ordering.
+        llvm::SmallVector<int64_t> inPerms, weightPerms;
+        inPerms.push_back(0); // N stays at the front for input.
+        // Then we expect the spatial dimensions
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          inPerms.push_back(i + 2);
+          weightPerms.push_back(i + 2);
+        }
+        inPerms.push_back(1);
+        weightPerms.append({1, 0});
+
+        paddedInput =
+            transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
+        weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
+        outputTensor =
+            transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
+
         conv = rewriter
                    .create<linalg::Conv3DNdhwcDhwcfQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                    .getResult(0);
+
+        llvm::SmallVector<int64_t> outPerms;
+        outPerms.push_back(0);
+        outPerms.push_back(inPerms.size() - 1);
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          outPerms.push_back(i + 1);
+        }
+        conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
+        break;
+      }
       default:
         return rewriter.notifyMatchFailure(
             op, "unimplemented: only 1D, 2D, and 3D convolution supported");
       };
 
-      llvm::SmallVector<int64_t> outPerms;
-      outPerms.push_back(0);
-      outPerms.push_back(inPerms.size() - 1);
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        outPerms.push_back(i + 1);
-      }
-      conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
-
       Type newResultType = getTypeConverter()->convertType(op.getType());
       if (accumulatorDType != resultDTy) {
         Type resultElementType =
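The 3-D branch retained above still routes through transposes. As a reading aid, a minimal standalone sketch of the permutations the case 3 block computes (plain C++ with std::vector instead of llvm::SmallVector; the printing scaffolding is ours, not torch-mlir code):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t numSpatialDims = 3; // the only case still taking the transpose path
  std::vector<int64_t> inPerms, weightPerms;
  inPerms.push_back(0); // N stays at the front for the input.
  for (size_t i = 0; i < numSpatialDims; ++i) {
    inPerms.push_back(i + 2); // spatial dims D, H, W keep their relative order
    weightPerms.push_back(i + 2);
  }
  inPerms.push_back(1);                          // channels move last: NCDHW -> NDHWC
  weightPerms.insert(weightPerms.end(), {1, 0}); // FCDHW -> DHWCF

  // Inverse permutation applied to the conv result: NDHWC -> NCDHW.
  std::vector<int64_t> outPerms{0, static_cast<int64_t>(inPerms.size()) - 1};
  for (size_t i = 0; i < numSpatialDims; ++i)
    outPerms.push_back(i + 1);

  auto dump = [](const std::vector<int64_t> &v) {
    for (int64_t p : v)
      std::cout << p << ' ';
    std::cout << '\n';
  };
  dump(inPerms);     // 0 2 3 4 1
  dump(weightPerms); // 2 3 4 1 0
  dump(outPerms);    // 0 4 1 2 3
}

The 2-D case no longer needs any of this bookkeeping, since linalg.conv_2d_nchw_fchw_q accepts the incoming NCHW input and FCHW weight directly.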
diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir
index 3023c0ba6..480b1eeb9 100644
--- a/test/Conversion/TorchToLinalg/convolution.mlir
+++ b/test/Conversion/TorchToLinalg/convolution.mlir
@@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
 // CHECK: %[[c7:.*]] = arith.constant 7 : i32
 // CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
 // CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
-// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [0, 2, 3, 1]
-// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [2, 3, 1, 0]
-// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
 // CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
 func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %false = torch.constant.bool false
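For context on the %[[c7]] and %[[c3]] operands in the CHECK lines: the _q conv variants take the input and weight zero points as extra scalar ins and subtract them before the multiply-accumulate, with accumulation in i32. A scalar reference of that arithmetic for one output point (our own sketch with a made-up flattened layout; N = 1 and stride = dilation = 1 to match the test; not torch-mlir or MLIR API code):

#include <cstdint>
#include <vector>

// Computes output[f, oh, ow] of a conv_2d_nchw_fchw_q-style convolution,
// with the input flattened as C*H*W and the weight as F*C*KH*KW.
int32_t convPointQ(const std::vector<int8_t> &input,
                   const std::vector<int8_t> &weight,
                   int C, int H, int W, int KH, int KW,
                   int f, int oh, int ow,
                   int32_t inputZp, int32_t weightZp) {
  int32_t acc = 0;
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < KH; ++kh)
      for (int kw = 0; kw < KW; ++kw) {
        // Zero points are subtracted from both operands before multiplying.
        int32_t x = input[(c * H + oh + kh) * W + (ow + kw)];
        int32_t k = weight[((f * C + c) * KH + kh) * KW + kw];
        acc += (x - inputZp) * (k - weightZp);
      }
  return acc;
}

Switching the 2-D lowering from conv_2d_nhwc_hwcf_q to conv_2d_nchw_fchw_q changes only which dimension each loop indexes, not this arithmetic, which is why the three linalg.transpose ops disappear from the expected output.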