[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)

The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where

1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W

This has now been fixed in
https://github.com/llvm/llvm-project/pull/73855, which broke the
existing torch-mlir lowering to that Op.

This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.

Fixes https://github.com/llvm/torch-mlir/issues/2622
pull/2626/head snapshot-20231209.1047
Felix Schneider 2023-12-08 14:18:23 +01:00 committed by GitHub
parent 8252656b6d
commit fb21a85874
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 23 deletions

View File

@ -848,6 +848,7 @@ public:
indices); indices);
}; };
// expand F,C,H,W -> G,F/G,C,H,W
auto expandWeight = [&](Value tensor) { auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>(); auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = makeShapeTorchCompatible(inType.getShape()); auto inShape = makeShapeTorchCompatible(inType.getShape());
@ -868,21 +869,19 @@ public:
Value paddedInputExpanded = expandGroups(paddedInput, 1); Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight); Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1); auto expandOutputTensor = expandGroups(outputTensor, 1);
// TODO: add 1D and 3D case // TODO: add 1D and 3D case
conv = rewriter conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>( .create<linalg::Conv2DNgchwGfchwOp>(
loc, outputTensorExpanded.getType(), loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded}, ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr) expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0); .getResult(0);
SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>( conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices); loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
} }
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());

View File

@ -23,14 +23,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"IscloseStaticModuleTrue_basic" "IscloseStaticModuleTrue_basic"
} }
if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
LINALG_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}
TORCHDYNAMO_XFAIL_SET = { TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors #### General TorchDynamo/PyTorch errors
@ -316,13 +308,6 @@ TORCHDYNAMO_XFAIL_SET = {
"ArangeStartOutViewModule_basic", "ArangeStartOutViewModule_basic",
} }
if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
TORCHDYNAMO_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}
TORCHDYNAMO_CRASHING_SET = { TORCHDYNAMO_CRASHING_SET = {
# No upstream decompositions. # No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)