mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] add support for depthwise qconv (#3564)
- Adds support for lowering depthwise + quantized convolution ops to linalg::DepthwiseConv2DNhwcHwcQOp.
- Renames the variable groupSize (a name that really describes C/G) to the more appropriate numGroups (G).
- Discovered in e2e testing that linalg does not accept (Cin = groups && Cout = K*groups for K>1) as a "depthwise" conv, so this also updates the case-checking to reflect this issue.
pull/3572/head
parent
50d6ce225f
commit
f1c74e1431
|
@ -788,7 +788,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
||||
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
|
||||
Value weight = adaptor.getWeight(); /* in form of F*C/G*H*W */
|
||||
Value bias = adaptor.getBias();
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
|
||||
|
@ -898,8 +898,8 @@ public:
|
|||
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
|
||||
|
||||
// Checks for valid group size
|
||||
int64_t groupSize;
|
||||
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groupSize)))
|
||||
int64_t numGroups;
|
||||
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only constant group size supported.");
|
||||
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
|
||||
|
@ -1118,14 +1118,14 @@ public:
|
|||
|
||||
Value conv;
|
||||
// the code so far is able to respect all numSpatialDims
|
||||
// the code below this point is numSpatialDims specific and groupSize
|
||||
// the code below this point is numSpatialDims specific and numGroups
|
||||
// specific
|
||||
// TODO: factor out the above code into a helper function, and then separate
|
||||
// convolution into:
|
||||
// - grouped 1d-3d
|
||||
// - grouped 1d-3d (quantized)
|
||||
// - ungrouped 1d-3d
|
||||
if (groupSize == 1 && !inputZp) {
|
||||
if (numGroups == 1 && !inputZp) {
|
||||
switch (numSpatialDims) {
|
||||
case 1:
|
||||
conv = rewriter
|
||||
|
@ -1166,7 +1166,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (groupSize == 1 && inputZp) {
|
||||
if (numGroups == 1 && inputZp) {
|
||||
// The quantized version uses a different channel ordering so we need to
|
||||
// permute the tensors in order to use the existing path. We should
|
||||
// eventually directly support this channel ordering.
|
||||
|
@ -1230,30 +1230,66 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
||||
// Special depthwise case
|
||||
// Special depthwise case: Cin = Cout = groups.
|
||||
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
||||
// of groups) to be depthwise in their documentation, but the linalg ops
|
||||
// apparently disagree.
|
||||
auto inShape = makeShapeTorchCompatible(
|
||||
cast<RankedTensorType>(input.getType()).getShape());
|
||||
auto weightShape = makeShapeTorchCompatible(
|
||||
cast<RankedTensorType>(weight.getType()).getShape());
|
||||
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
||||
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
|
||||
// Collapse weight shape
|
||||
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
||||
weightShape[1] == 1) {
|
||||
// Collapse weight shape (C/G == 1)
|
||||
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
||||
SmallVector<int64_t> collapsedShape{
|
||||
(weightShape[0] == kUnknownSize ? kUnknownSize
|
||||
: weightShape[0] * weightShape[1]),
|
||||
weightShape[2], weightShape[3]};
|
||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
|
||||
weightShape[2], weightShape[3]};
|
||||
Type collapsedType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
||||
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, collapsedType, weight, collapsedDims);
|
||||
if (!inputZp) {
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
} else {
|
||||
// currently, the only named depthwise qconv op is nhwc_hwc
|
||||
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
||||
// linalg conv result nhwc -> nchw
|
||||
// inPerms = [0, 2, 3, 1]
|
||||
// weightPerms = [1, 2, 0]
|
||||
// resultPerms = [0, 3, 1, 2]
|
||||
llvm::SmallVector<int64_t> inPerms, weightPerms, resultPerms;
|
||||
inPerms.push_back(0);
|
||||
resultPerms.append({0, static_cast<int64_t>(numSpatialDims + 1)});
|
||||
for (size_t i = 0; i < numSpatialDims; ++i) {
|
||||
inPerms.push_back(i + 2);
|
||||
weightPerms.push_back(i + 1);
|
||||
resultPerms.push_back(i + 1);
|
||||
}
|
||||
inPerms.push_back(1);
|
||||
weightPerms.push_back(0);
|
||||
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
paddedInput =
|
||||
transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
|
||||
collapsedWeight =
|
||||
transposeValue(op.getLoc(), collapsedWeight, weightPerms, rewriter);
|
||||
outputTensor =
|
||||
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
|
||||
|
||||
conv =
|
||||
rewriter
|
||||
.create<linalg::DepthwiseConv2DNhwcHwcQOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight, inputZp, weightZp},
|
||||
outputTensor, stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
// convert output nhwc -> nchw
|
||||
conv = transposeValue(op.getLoc(), conv, resultPerms, rewriter);
|
||||
}
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
if (accumulatorDType != resultDTy) {
|
||||
|
@ -1274,12 +1310,12 @@ public:
|
|||
SmallVector<int64_t> outShape;
|
||||
for (auto i = 0; i < (long)inShape.size(); i++) {
|
||||
if (i == 1) {
|
||||
outShape.push_back(groupSize);
|
||||
outShape.push_back(numGroups);
|
||||
}
|
||||
if (i == (long)dim) {
|
||||
outShape.push_back(inShape[i] == kUnknownSize
|
||||
? kUnknownSize
|
||||
: inShape[i] / groupSize);
|
||||
: inShape[i] / numGroups);
|
||||
} else {
|
||||
outShape.push_back(inShape[i]);
|
||||
}
|
||||
|
@ -1305,8 +1341,8 @@ public:
|
|||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||
|
||||
SmallVector<int64_t> outShape{
|
||||
groupSize,
|
||||
(inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)};
|
||||
numGroups,
|
||||
(inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
|
||||
outShape.append(inShape.begin() + 1, inShape.end());
|
||||
|
||||
SmallVector<ReassociationIndices> indices{{0, 1}};
|
||||
|
|
|
@ -16,9 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
|
|||
print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
|
||||
|
||||
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
|
||||
|
@ -250,9 +247,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ScatterValueIntModule_basic",
|
||||
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
# Lowering not present for this case
|
||||
|
@ -281,7 +275,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
# Dynamo not supporting conv_tbc
|
||||
"ConvTbcModule_basic",
|
||||
|
@ -380,8 +376,9 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"ConvTbcModule_basic",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
|
@ -547,7 +544,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"ConvTbcModule_basic",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
|
@ -2204,7 +2203,9 @@ LTC_XFAIL_SET = {
|
|||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"ConvTranspose2DQInt8_basic",
|
||||
}
|
||||
|
||||
|
@ -2350,7 +2351,9 @@ ONNX_XFAIL_SET = {
|
|||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingModule_basic",
|
||||
"Conv3dModule_basic",
|
||||
|
@ -2718,7 +2721,6 @@ ONNX_XFAIL_SET = {
|
|||
"BernoulliModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"CopyWithDifferentDTypesAndSizesModule_basic",
|
||||
"CopyWithDifferentDTypesModule_basic",
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
|
@ -2922,7 +2924,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
||||
"Conv3dModule_basic",
|
||||
|
@ -3715,7 +3719,9 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
"Conv2dQInt8Module_not_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
||||
|
|
|
@ -1156,21 +1156,12 @@ def ConvTbcModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))
|
||||
|
||||
|
||||
class Conv2dQInt8Module(torch.nn.Module):
|
||||
class Conv2dQInt8ModuleBase(torch.nn.Module):
|
||||
def __init__(self, groups=1):
|
||||
self.groups = groups
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.int8, True),
|
||||
([-1, -1, -1, -1], torch.int8, True),
|
||||
([-1], torch.float, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight, bias):
|
||||
def _forward(self, inputVec, weight, bias):
|
||||
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
|
||||
inputVec = torch.dequantize(inputVec)
|
||||
|
||||
|
@ -1191,7 +1182,49 @@ class Conv2dQInt8Module(torch.nn.Module):
|
|||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8Module())
|
||||
class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.int8, True),
|
||||
([-1, -1, -1, -1], torch.int8, True),
|
||||
([-1], torch.float, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight, bias):
|
||||
return self._forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3, 12, 12], torch.int8, True),
|
||||
([3, 1, 5, 3], torch.int8, True),
|
||||
([3], torch.float, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight, bias):
|
||||
return self._forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3, 12, 12], torch.int8, True),
|
||||
([6, 1, 5, 3], torch.int8, True),
|
||||
([6], torch.float, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight, bias):
|
||||
return self._forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn())
|
||||
def Conv2dQInt8Module_basic(module, tu: TestUtils):
|
||||
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
|
||||
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
|
||||
|
@ -1199,7 +1232,7 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
|
|||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
|
||||
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
|
||||
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
|
||||
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
|
||||
|
@ -1207,6 +1240,24 @@ def Conv2dQInt8Module_grouped(module, tu: TestUtils):
|
|||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
|
||||
def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
|
||||
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
|
||||
weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
|
||||
bias = torch.rand(3)
|
||||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dQInt8ModuleStatic_MoreOutChannels(groups=3)
|
||||
)
|
||||
def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
|
||||
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
|
||||
weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
|
||||
bias = torch.rand(6)
|
||||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class ConvTranspose2DQInt8Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue