[TOSA] Fix conversion for depthwise convolutions (#2398)

* [TOSA] Fix conversion for depthwise convolutions

* Add e2e tests for depthwise and grouped convolutions

Co-authored-by: Lucas Camphausen <lucas.camphausen@iml.fraunhofer.de>
Simon Camphausen 2023-08-18 17:15:54 +02:00 committed by GitHub
parent 594a1fa471
commit d77b9cf7ae
3 changed files with 156 additions and 35 deletions
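For context, "depthwise" and "grouped" here refer to the groups argument of torch.nn.Conv2d: a depthwise convolution sets groups equal to the number of input channels (optionally with a channel multiplier), while a grouped convolution uses 1 < groups < in_channels. The sketch below is illustrative only (plain PyTorch, not part of this patch) and mirrors the channel counts used by the new e2e tests.

import torch

# Full convolution: every output channel sees all 4 input channels.
full = torch.nn.Conv2d(in_channels=4, out_channels=10, kernel_size=3, groups=1)
# Depthwise: groups == in_channels, one filter per input channel.
depthwise = torch.nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, groups=4)
# Depthwise with channel multiplier 2: out_channels is a multiple of groups.
depthwise_mult = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, groups=4)
# Grouped: 1 < groups < in_channels.
grouped = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, groups=2)

# PyTorch weight layout is (out_channels, in_channels // groups, kH, kW);
# the depthwise cases are exactly the ones where the second dimension is 1,
# which is the condition the lowering below keys on.
print(full.weight.shape)            # torch.Size([10, 4, 3, 3])
print(depthwise.weight.shape)       # torch.Size([4, 1, 3, 3])
print(depthwise_mult.weight.shape)  # torch.Size([8, 1, 3, 3])
print(grouped.weight.shape)         # torch.Size([8, 2, 3, 3])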

View File

@@ -13,7 +13,11 @@
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
from torch_mlir._version import torch_version_for_comparison, version
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier"
}
TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors
@@ -276,6 +280,10 @@ TORCHDYNAMO_XFAIL_SET = {
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
}
TORCHDYNAMO_CRASHING_SET = {
@@ -640,6 +648,10 @@ STABLEHLO_PASS_SET = {
"AvgPool1dStaticModule_basic",
"AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
@@ -989,6 +1001,8 @@ TOSA_PASS_SET = {
"ElementwiseIsnanModule_basic",
"TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",

View File

@@ -1898,6 +1898,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto biasElemTy =
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();
int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
return rewriter.notifyMatchFailure(op, "non-const group size unsupported");
} else if (groups != 1 && weightShape[1] != 1) {
return rewriter.notifyMatchFailure(
op, "group size must be 1 (convolution) or weight.dim(1) must be 1 "
"(depthwise convolution)");
}
SmallVector<int64_t, 2> stride;
if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride)))
return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");
@@ -1918,7 +1927,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");
// TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose.
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
// Perform the necessary transformations.
std::optional<Value> nchwToNhwcTransposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{0, 2, 3, 1},
@@ -1935,26 +1945,82 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
nchwToNhwcTransposeConst.value())
.getResult();
SmallVector<int64_t> transposedWeightShape(
{weightShape[0], weightShape[2], weightShape[3], weightShape[1]});
auto transposedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
auto transposedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedWeightType), weight,
nchwToNhwcTransposeConst.value())
.getResult();
SmallVector<int64_t> transformedWeightShape;
RankedTensorType transformedWeightType;
Value transformedWeight;
int64_t outputCDim;
if (groups == 1) {
// full convolution: O(I/G)HW -> OHWI
transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3],
weightShape[1]};
transformedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
transformedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transformedWeightType), weight,
nchwToNhwcTransposeConst.value())
.getResult();
outputCDim = transformedWeightShape[0];
} else if (weightShape[1] == 1) {
// depthwise convolution: O(I/G)HW -> HWIM
// transpose: O(I/G)HW -> HWO(I/G)
std::optional<Value> transposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{2, 3, 0, 1},
/*shape=*/{static_cast<int32_t>(4)});
SmallVector<int64_t> transposedWeightShape = {
weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
auto transposedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
auto transposedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedWeightType), weight,
transposeConst.value())
.getResult();
// reshape: HWO(I/G) -> HWIM
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
if (outputCDim == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "number of output channels must be statically known for "
"depthwise convolutions");
}
transformedWeightShape = {
transposedWeightShape[0],
transposedWeightShape[1],
groups,
outputCDim / groups,
};
transformedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
transformedWeight =
rewriter
.create<tosa::ReshapeOp>(
op->getLoc(),
getTypeConverter()->convertType(transformedWeightType),
transposedWeight,
rewriter.getDenseI64ArrayAttr(transformedWeightShape))
.getResult();
} else {
llvm_unreachable("Unhandled convolution type");
}
int64_t outputHDim, outputWDim;
if (inputTy.hasStaticShape()) {
outputHDim = (transposedInputShape[1] + padding[0] + padding[1] -
dilation[0] * (transposedWeightShape[1] - 1) - 1) /
int64_t inputHDim = inputShape[2];
int64_t inputWDim = inputShape[3];
int64_t weightHDim = weightShape[2];
int64_t weightWDim = weightShape[3];
outputHDim = (inputHDim + padding[0] + padding[1] -
dilation[0] * (weightHDim - 1) - 1) /
stride[0] +
1;
outputWDim = (transposedInputShape[2] + padding[2] + padding[3] -
dilation[1] * (transposedWeightShape[2] - 1) - 1) /
outputWDim = (inputWDim + padding[2] + padding[3] -
dilation[1] * (weightWDim - 1) - 1) /
stride[1] +
1;
} else {
@@ -1965,19 +2031,36 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
// Output shape is NHWC, to be transposed back to NCHW. Output elemTy for
// quantized input is i32, which gets rescaled down to quantized output range.
SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
outputWDim, transposedWeightShape[0]};
outputWDim, outputCDim};
auto convOpTy =
RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy);
Value convOpResult =
rewriter
.create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy),
transposedInput, transposedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
Value convOpResult;
if (groups == 1) {
// full convolution
convOpResult =
rewriter
.create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
} else if (weightShape[1] == 1) {
// depthwise convolution
convOpResult =
rewriter
.create<tosa::DepthwiseConv2DOp>(
op->getLoc(), getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
} else {
llvm_unreachable("Unhandled convolution type");
}
std::optional<Value> nhwcToNchwTransposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,

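The key layout detail in the hunks above: PyTorch stores convolution weights as O(I/G)HW, tosa.conv2d expects OHWI, and tosa.depthwise_conv2d expects HWIM with M = out_channels / groups (the channel multiplier). The Python sketch below is illustrative only (variable names are mine, not from the patch); it mirrors the transpose-then-reshape performed in the lowering and the output-size arithmetic, using the shapes from the new depthwise-multiplier e2e test.

import torch

# Shapes from the new e2e test: NCHW input (5, 4, 10, 20), depthwise with multiplier 2.
n, in_ch, in_h, in_w = 5, 4, 10, 20
groups, out_ch, kh, kw = 4, 8, 3, 3
stride, dilation, pad = 2, 3, 3

# PyTorch weight layout is O(I/G)HW; for a depthwise conv the second dim is 1.
weight = torch.rand(out_ch, in_ch // groups, kh, kw)   # (8, 1, 3, 3)

# Transpose O(I/G)HW -> HWO(I/G), then reshape HWO(I/G) -> HWIM,
# where I == groups and M == out_ch // groups (the channel multiplier).
hwim = weight.permute(2, 3, 0, 1).reshape(kh, kw, groups, out_ch // groups)
print(hwim.shape)   # torch.Size([3, 3, 4, 2])

# Output spatial size as computed in the lowering
# (the lowering's padding vector is [top, bottom, left, right]).
out_h = (in_h + pad + pad - dilation * (kh - 1) - 1) // stride + 1   # 5
out_w = (in_w + pad + pad - dilation * (kw - 1) - 1) // stride + 1   # 10

# Cross-check the expected output shape against PyTorch's grouped convolution.
x = torch.rand(n, in_ch, in_h, in_w)
y = torch.nn.functional.conv2d(x, weight, stride=stride, padding=pad,
                               dilation=dilation, groups=groups)
assert y.shape == (n, out_ch, out_h, out_w)   # (5, 8, 5, 10)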
View File

@@ -112,32 +112,56 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
def __init__(self):
def __init__(self, out_channels, groups):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(in_channels=2,
out_channels=10,
self.conv = torch.nn.Conv2d(in_channels=4,
out_channels=out_channels,
kernel_size=3,
padding=3,
stride=2,
dilation=3,
bias=False)
bias=False,
groups=groups)
self.train(False)
@export
@annotate_args([
None,
([5, 2, 10, 20], torch.float32, True),
([5, 4, 10, 20], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1))
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
module.forward(tu.rand(5, 4, 10, 20))
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4))
def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4))
def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2))
def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2))
def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))
# ==============================================================================
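As a quick sanity check on the new test configurations, the snippet below re-creates the module's hyperparameters directly in PyTorch (a sketch, not part of the test suite) and runs each (out_channels, groups) combination on the annotated input shape [5, 4, 10, 20]:

import torch

configs = {
    "basic": (10, 1),
    "depthwise": (4, 4),
    "depthwise_multiplier": (8, 4),
    "grouped": (4, 2),
    "grouped_multiplier": (8, 2),
}
x = torch.rand(5, 4, 10, 20)
for name, (out_channels, groups) in configs.items():
    conv = torch.nn.Conv2d(in_channels=4, out_channels=out_channels, kernel_size=3,
                           padding=3, stride=2, dilation=3, bias=False, groups=groups)
    print(name, tuple(conv(x).shape))
# basic (5, 10, 5, 10), depthwise (5, 4, 5, 10), depthwise_multiplier (5, 8, 5, 10),
# grouped (5, 4, 5, 10), grouped_multiplier (5, 8, 5, 10)

Only the depthwise cases reach the new tosa.depthwise_conv2d path (their weights have shape[1] == 1); the grouped cases are still rejected by the TOSA lowering and are only enabled for the StableHLO pass set above.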