mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Use Op with native channel order for quantized conv2d (#3807)
I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (https://github.com/llvm/llvm-project/pull/107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore removes the requirement to try to optimize these away in downstream passes.
parent 42ba541c68
commit aca33f1742
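As a rough illustration of the change, for the dynamically shaped si8 convolution exercised by the updated test the emitted IR changes roughly as sketched below. Value names (%input, %weight, %izp, %wzp, %acc, %init*) and the unit strides/dilations are illustrative placeholders, not taken from the actual lowering output.

  Before (NHWC path, transposes around the convolution):

    %inT = linalg.transpose ins(%input : tensor<?x?x?x?xi8>)
               outs(%init0 : tensor<?x?x?x?xi8>) permutation = [0, 2, 3, 1]
    %wT = linalg.transpose ins(%weight : tensor<?x?x?x?xi8>)
              outs(%init1 : tensor<?x?x?x?xi8>) permutation = [2, 3, 1, 0]
    %res = linalg.conv_2d_nhwc_hwcf_q
               {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
               ins(%inT, %wT, %izp, %wzp : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
               outs(%accT : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
    // ...followed by a third transpose to bring the NHWC result back to NCHW.

  After (native channel-first op, no transposes):

    %res = linalg.conv_2d_nchw_fchw_q
               {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
               ins(%input, %weight, %izp, %wzp : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
               outs(%acc : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>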
@@ -1125,54 +1125,57 @@ public:
    }
 
    if (numGroups == 1 && inputZp) {
-      // The quantized version uses a different channel ordering so we need to
-      // permute the tensors in order to use the existing path. We should
-      // eventually directly support this channel ordering.
-      llvm::SmallVector<int64_t> inPerms, weightPerms;
-      inPerms.push_back(0); // N stays at the front for input.
-      // Then we expect the spatial dimensions
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        inPerms.push_back(i + 2);
-        weightPerms.push_back(i + 2);
-      }
-      inPerms.push_back(1);
-      weightPerms.append({1, 0});
-
-      paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
-      weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
-      outputTensor =
-          transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
-
      switch (numSpatialDims) {
      case 2:
        conv = rewriter
-                   .create<linalg::Conv2DNhwcHwcfQOp>(
+                   .create<linalg::Conv2DNchwFchwQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                   .getResult(0);
        break;
-      case 3:
+      case 3: {
+        // The quantized version uses a different channel ordering so we need to
+        // permute the tensors in order to use the existing path. We should
+        // eventually directly support this channel ordering.
+        llvm::SmallVector<int64_t> inPerms, weightPerms;
+        inPerms.push_back(0); // N stays at the front for input.
+        // Then we expect the spatial dimensions
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          inPerms.push_back(i + 2);
+          weightPerms.push_back(i + 2);
+        }
+        inPerms.push_back(1);
+        weightPerms.append({1, 0});
+
+        paddedInput =
+            transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
+        weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
+        outputTensor =
+            transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
+
        conv = rewriter
                   .create<linalg::Conv3DNdhwcDhwcfQOp>(
                       loc, outputTensor.getType(),
                       ValueRange{paddedInput, weight, inputZp, weightZp},
                       outputTensor, stridesAttr, dilationAttr)
                   .getResult(0);
+
+        llvm::SmallVector<int64_t> outPerms;
+        outPerms.push_back(0);
+        outPerms.push_back(inPerms.size() - 1);
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          outPerms.push_back(i + 1);
+        }
+        conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
+
        break;
+      }
      default:
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only 1D, 2D, and 3D convolution supported");
      };
 
-      llvm::SmallVector<int64_t> outPerms;
-      outPerms.push_back(0);
-      outPerms.push_back(inPerms.size() - 1);
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        outPerms.push_back(i + 1);
-      }
-      conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
-
      Type newResultType = getTypeConverter()->convertType(op.getType());
      if (accumulatorDType != resultDTy) {
        Type resultElementType =
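The 3-D case, by contrast, still routes through the channel-last linalg op: per the code above, for numSpatialDims == 3 the permutations come out as inPerms = [0, 2, 3, 4, 1] (NCDHW -> NDHWC), weightPerms = [2, 3, 4, 1, 0] (FCDHW -> DHWCF), and, after the convolution, outPerms = [0, 4, 1, 2, 3] to restore NCDHW. A rough sketch of the IR this produces, again with illustrative value names, dynamic shapes, and unit strides/dilations:

    %inT = linalg.transpose ins(%input : tensor<?x?x?x?x?xi8>)
               outs(%i0 : tensor<?x?x?x?x?xi8>) permutation = [0, 2, 3, 4, 1]
    %wT = linalg.transpose ins(%weight : tensor<?x?x?x?x?xi8>)
              outs(%i1 : tensor<?x?x?x?x?xi8>) permutation = [2, 3, 4, 1, 0]
    %c = linalg.conv_3d_ndhwc_dhwcf_q
             {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>}
             ins(%inT, %wT, %izp, %wzp : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
             outs(%accT : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
    %res = linalg.transpose ins(%c : tensor<?x?x?x?x?xi32>)
               outs(%i2 : tensor<?x?x?x?x?xi32>) permutation = [0, 4, 1, 2, 3]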
@@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
 // CHECK: %[[c7:.*]] = arith.constant 7 : i32
 // CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
 // CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
-// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [0, 2, 3, 1]
-// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
-// CHECK-SAME: permutation = [2, 3, 1, 0]
-// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
 // CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
 func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %false = torch.constant.bool false