[TorchToLinalg] Use Op with native channel order for quantized conv2d (#3807)

I've upstreamed the quantized linalg op for 2d convolution that uses the
"channel-first" ordering native to torch
(https://github.com/llvm/llvm-project/pull/107740).

This patch changes the lowering of the quantized 2d case of
`aten.convolution` accordingly. This saves three transpositions per
convolution (input, weights, result) and removes the need to optimize
them away in downstream passes.
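
To illustrate the effect, here is a minimal sketch of the kind of IR the 2d lowering now produces, adapted from the updated lit test below; the static shapes and the concrete zero-point values are illustrative only:

```mlir
// Quantized conv2d emitted directly in torch's native NCHW (input) / FCHW
// (weight) layout, so no linalg.transpose ops are needed around it.
func.func @quantized_conv2d_nchw(%input: tensor<1x3x32x32xi8>,
                                 %weight: tensor<8x3x3x3xi8>) -> tensor<1x8x30x30xi32> {
  // Zero points for input and weight (placeholder values).
  %input_zp = arith.constant 7 : i32
  %weight_zp = arith.constant 3 : i32
  // i32 accumulator initialized to zero.
  %c0 = arith.constant 0 : i32
  %init = tensor.empty() : tensor<1x8x30x30xi32>
  %acc = linalg.fill ins(%c0 : i32) outs(%init : tensor<1x8x30x30xi32>) -> tensor<1x8x30x30xi32>
  %conv = linalg.conv_2d_nchw_fchw_q
            {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
            ins(%input, %weight, %input_zp, %weight_zp
                : tensor<1x3x32x32xi8>, tensor<8x3x3x3xi8>, i32, i32)
            outs(%acc : tensor<1x8x30x30xi32>) -> tensor<1x8x30x30xi32>
  return %conv : tensor<1x8x30x30xi32>
}
```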
Felix Schneider 2024-10-22 20:26:16 +02:00 committed by GitHub
parent 42ba541c68
commit aca33f1742
2 changed files with 33 additions and 34 deletions

View File

@@ -1125,6 +1125,16 @@ public:
       }
     }
     if (numGroups == 1 && inputZp) {
+      switch (numSpatialDims) {
+      case 2:
+        conv = rewriter
+                   .create<linalg::Conv2DNchwFchwQOp>(
+                       loc, outputTensor.getType(),
+                       ValueRange{paddedInput, weight, inputZp, weightZp},
+                       outputTensor, stridesAttr, dilationAttr)
+                   .getResult(0);
+        break;
+      case 3: {
         // The quantized version uses a different channel ordering so we need to
         // permute the tensors in order to use the existing path. We should
         // eventually directly support this channel ordering.
@@ -1138,32 +1148,18 @@ public:
         inPerms.push_back(1);
         weightPerms.append({1, 0});
-        paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
+        paddedInput =
+            transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
         weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
         outputTensor =
             transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
-        switch (numSpatialDims) {
-        case 2:
-          conv = rewriter
-                     .create<linalg::Conv2DNhwcHwcfQOp>(
-                         loc, outputTensor.getType(),
-                         ValueRange{paddedInput, weight, inputZp, weightZp},
-                         outputTensor, stridesAttr, dilationAttr)
-                     .getResult(0);
-          break;
-        case 3:
         conv = rewriter
                    .create<linalg::Conv3DNdhwcDhwcfQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                    .getResult(0);
-          break;
-        default:
-          return rewriter.notifyMatchFailure(
-              op, "unimplemented: only 1D, 2D, and 3D convolution supported");
-        };
         llvm::SmallVector<int64_t> outPerms;
         outPerms.push_back(0);
@@ -1173,6 +1169,13 @@ public:
         }
         conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
+        break;
+      }
+      default:
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only 1D, 2D, and 3D convolution supported");
+      };
       Type newResultType = getTypeConverter()->convertType(op.getType());
       if (accumulatorDType != resultDTy) {
         Type resultElementType =

View File

@@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
 // CHECK: %[[c7:.*]] = arith.constant 7 : i32
 // CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
 // CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
-// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [0, 2, 3, 1]
-// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [2, 3, 1, 0]
-// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
 // CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
 func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %false = torch.constant.bool false