mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix conversion for depthwise convolutions (#2398)
* [TOSA] Fix conversion for depthwise convolutions * Add e2e tests for depthwise and grouped convolutions Co-authored-by: Lucas Camphausen <lucas.camphausen@iml.fraunhofer.de>pull/2403/head
parent
594a1fa471
commit
d77b9cf7ae
|
@ -13,7 +13,11 @@
|
|||
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
from torch_mlir._version import torch_version_for_comparison, version
|
||||
|
||||
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier"
|
||||
}
|
||||
|
||||
TORCHDYNAMO_XFAIL_SET = {
|
||||
#### General TorchDynamo/PyTorch errors
|
||||
|
@ -276,6 +280,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
|
||||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
||||
}
|
||||
|
||||
TORCHDYNAMO_CRASHING_SET = {
|
||||
|
@ -640,6 +648,10 @@ STABLEHLO_PASS_SET = {
|
|||
"AvgPool1dStaticModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
||||
"Convolution2DStaticModule_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
|
@ -989,6 +1001,8 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseIsnanModule_basic",
|
||||
"TypePromotionAlphaWiderModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"BatchNorm1DModule_basic",
|
||||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
|
|
|
@ -1898,6 +1898,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
auto biasElemTy =
|
||||
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();
|
||||
|
||||
int64_t groups;
|
||||
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
|
||||
return rewriter.notifyMatchFailure(op, "non-const group size unsupported");
|
||||
} else if (groups != 1 && weightShape[1] != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "group size must be 1 (convolution) or weight.dim(1) must be 1 "
|
||||
"(depthwise convolution)");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> stride;
|
||||
if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride)))
|
||||
return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");
|
||||
|
@ -1918,7 +1927,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dilation list unsupported");
|
||||
|
||||
// TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose.
|
||||
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
|
||||
// Perform the necessary transformations.
|
||||
std::optional<Value> nchwToNhwcTransposeConst =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op,
|
||||
/*vec=*/{0, 2, 3, 1},
|
||||
|
@ -1935,26 +1945,82 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
nchwToNhwcTransposeConst.value())
|
||||
.getResult();
|
||||
|
||||
SmallVector<int64_t> transposedWeightShape(
|
||||
{weightShape[0], weightShape[2], weightShape[3], weightShape[1]});
|
||||
auto transposedWeightType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
|
||||
auto transposedWeight =
|
||||
rewriter
|
||||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transposedWeightType), weight,
|
||||
nchwToNhwcTransposeConst.value())
|
||||
.getResult();
|
||||
SmallVector<int64_t> transformedWeightShape;
|
||||
RankedTensorType transformedWeightType;
|
||||
Value transformedWeight;
|
||||
int64_t outputCDim;
|
||||
if (groups == 1) {
|
||||
// full convolution: O(I/G)HW-> OHWI
|
||||
transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3],
|
||||
weightShape[1]};
|
||||
transformedWeightType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
|
||||
transformedWeight =
|
||||
rewriter
|
||||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transformedWeightType), weight,
|
||||
nchwToNhwcTransposeConst.value())
|
||||
.getResult();
|
||||
outputCDim = transformedWeightShape[0];
|
||||
} else if (weightShape[1] == 1) {
|
||||
// depthwise convolution: O(I/G)HW-> HWIM)
|
||||
// transpose: O(I/G)HW -> HWO(I/G)
|
||||
std::optional<Value> transposeConst =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op,
|
||||
/*vec=*/{2, 3, 0, 1},
|
||||
/*shape=*/{static_cast<int32_t>(4)});
|
||||
SmallVector<int64_t> transposedWeightShape = {
|
||||
weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
|
||||
auto transposedWeightType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
|
||||
auto transposedWeight =
|
||||
rewriter
|
||||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transposedWeightType), weight,
|
||||
transposeConst.value())
|
||||
.getResult();
|
||||
|
||||
// reshape: HWO(I/G) -> HWIM
|
||||
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
|
||||
if (outputCDim == kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "number of output channels must be statically known for "
|
||||
"depthwise convolutions");
|
||||
}
|
||||
transformedWeightShape = {
|
||||
transposedWeightShape[0],
|
||||
transposedWeightShape[1],
|
||||
groups,
|
||||
outputCDim / groups,
|
||||
};
|
||||
transformedWeightType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
|
||||
transformedWeight =
|
||||
rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transformedWeightType),
|
||||
transposedWeight,
|
||||
rewriter.getDenseI64ArrayAttr(transformedWeightShape))
|
||||
.getResult();
|
||||
} else {
|
||||
llvm_unreachable("Unhandled convolution type");
|
||||
}
|
||||
|
||||
int64_t outputHDim, outputWDim;
|
||||
if (inputTy.hasStaticShape()) {
|
||||
outputHDim = (transposedInputShape[1] + padding[0] + padding[1] -
|
||||
dilation[0] * (transposedWeightShape[1] - 1) - 1) /
|
||||
int64_t inputHDim = inputShape[2];
|
||||
int64_t inputWDim = inputShape[3];
|
||||
int64_t weightHDim = weightShape[2];
|
||||
int64_t weightWDim = weightShape[3];
|
||||
outputHDim = (inputHDim + padding[0] + padding[1] -
|
||||
dilation[0] * (weightHDim - 1) - 1) /
|
||||
stride[0] +
|
||||
1;
|
||||
outputWDim = (transposedInputShape[2] + padding[2] + padding[3] -
|
||||
dilation[1] * (transposedWeightShape[2] - 1) - 1) /
|
||||
outputWDim = (inputWDim + padding[2] + padding[3] -
|
||||
dilation[1] * (weightWDim - 1) - 1) /
|
||||
stride[1] +
|
||||
1;
|
||||
} else {
|
||||
|
@ -1965,19 +2031,36 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
// Output shape is NHWC, to be transposed back to NCHW. Output elemTy for
|
||||
// quantized input is i32, which gets rescaled down to quantized output range.
|
||||
SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
|
||||
outputWDim, transposedWeightShape[0]};
|
||||
outputWDim, outputCDim};
|
||||
auto convOpTy =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy);
|
||||
|
||||
Value convOpResult =
|
||||
rewriter
|
||||
.create<tosa::Conv2DOp>(op->getLoc(),
|
||||
getTypeConverter()->convertType(convOpTy),
|
||||
transposedInput, transposedWeight, bias,
|
||||
rewriter.getDenseI64ArrayAttr(padding),
|
||||
rewriter.getDenseI64ArrayAttr(stride),
|
||||
rewriter.getDenseI64ArrayAttr(dilation))
|
||||
.getResult();
|
||||
Value convOpResult;
|
||||
if (groups == 1) {
|
||||
// full convolution
|
||||
convOpResult =
|
||||
rewriter
|
||||
.create<tosa::Conv2DOp>(op->getLoc(),
|
||||
getTypeConverter()->convertType(convOpTy),
|
||||
transposedInput, transformedWeight, bias,
|
||||
rewriter.getDenseI64ArrayAttr(padding),
|
||||
rewriter.getDenseI64ArrayAttr(stride),
|
||||
rewriter.getDenseI64ArrayAttr(dilation))
|
||||
.getResult();
|
||||
} else if (weightShape[1] == 1) {
|
||||
// depthwise convolution
|
||||
convOpResult =
|
||||
rewriter
|
||||
.create<tosa::DepthwiseConv2DOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(convOpTy),
|
||||
transposedInput, transformedWeight, bias,
|
||||
rewriter.getDenseI64ArrayAttr(padding),
|
||||
rewriter.getDenseI64ArrayAttr(stride),
|
||||
rewriter.getDenseI64ArrayAttr(dilation))
|
||||
.getResult();
|
||||
} else {
|
||||
llvm_unreachable("Unhandled convolution type");
|
||||
}
|
||||
|
||||
std::optional<Value> nhwcToNchwTransposeConst =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op,
|
||||
|
|
|
@ -112,32 +112,56 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
|
|||
|
||||
class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, out_channels, groups):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.conv = torch.nn.Conv2d(in_channels=2,
|
||||
out_channels=10,
|
||||
self.conv = torch.nn.Conv2d(in_channels=4,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=3,
|
||||
stride=2,
|
||||
dilation=3,
|
||||
bias=False)
|
||||
bias=False,
|
||||
groups=groups)
|
||||
self.train(False)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 2, 10, 20], torch.float32, True),
|
||||
([5, 4, 10, 20], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1))
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
||||
t = tu.rand(5, 2, 10, 20)
|
||||
module.forward(t)
|
||||
module.forward(tu.rand(5, 4, 10, 20))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4))
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 10, 20))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4))
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 10, 20))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2))
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 10, 20))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2))
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 10, 20))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
Loading…
Reference in New Issue