mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg op `linalg.conv_2d_ngchw_fgchw` had a bug: (1) weights were accessed as G,F,C,H,W instead of F,G,C,H,W, and (2) the output was accessed as N,F,G,H,W instead of N,G,F,H,W. This was fixed upstream in https://github.com/llvm/llvm-project/pull/73855, which in turn broke the torch-mlir lowering to that op. This patch switches the torch-mlir lowering to the newly introduced `linalg.conv_2d_ngchw_gfchw` op, which accesses weights in an order compatible with PyTorch's memory layout. Fixes https://github.com/llvm/torch-mlir/issues/2622.
parent 8252656b6d
commit fb21a85874
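The core of the fix is a layout argument: PyTorch stores `conv2d` weights as F,C/G,H,W, so splitting the leading F dimension into (G, F/G) is a pure reshape, whereas the F-before-G order assumed by the old `fgchw` op would require an actual transpose. Below is a minimal PyTorch sketch (illustration only, not part of the patch; all shapes and names are made up) of the expand-convolve-collapse scheme the lowering uses:

# Sketch: why the gfchw weight ordering matches PyTorch's memory layout.
# PyTorch stores conv2d weights as (F, C/G, Kh, Kw); splitting the leading
# F dimension into (G, F/G) is a pure view -- exactly the G,F/G,C,H,W
# order that linalg.conv_2d_ngchw_gfchw expects.
import torch

N, C, F, G = 2, 4, 6, 2
x = torch.randn(N, C, 8, 8)
w = torch.randn(F, C // G, 3, 3)          # PyTorch layout: F, C/G, Kh, Kw

ref = torch.nn.functional.conv2d(x, w, groups=G)

# Emulate the lowering: expand N,C,H,W -> N,G,C/G,H,W and
# F,C/G,Kh,Kw -> G,F/G,C/G,Kh,Kw, convolve per group, then collapse.
xg = x.reshape(N, G, C // G, 8, 8)
wg = w.reshape(G, F // G, C // G, 3, 3)   # pure view, no data movement
out = torch.cat(
    [torch.nn.functional.conv2d(xg[:, g], wg[g]) for g in range(G)],
    dim=1,
)
assert torch.allclose(ref, out, atol=1e-5)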
@@ -848,6 +848,7 @@ public:
           indices);
     };

+    // expand F,C,H,W -> G,F/G,C,H,W
     auto expandWeight = [&](Value tensor) {
       auto inType = tensor.getType().cast<RankedTensorType>();
       auto inShape = makeShapeTorchCompatible(inType.getShape());
@@ -868,21 +869,19 @@ public:
     Value paddedInputExpanded = expandGroups(paddedInput, 1);
     Value weightExpanded = expandWeight(weight);
-    Value outputTensorExpanded = expandGroups(outputTensor, 1);
+    auto expandOutputTensor = expandGroups(outputTensor, 1);

     // TODO: add 1D and 3D case
     conv = rewriter
-               .create<linalg::Conv2DNgchwFgchwOp>(
-                   loc, outputTensorExpanded.getType(),
+               .create<linalg::Conv2DNgchwGfchwOp>(
+                   loc, expandOutputTensor.getResultType(),
                    ValueRange{paddedInputExpanded, weightExpanded},
-                   outputTensorExpanded, stridesAttr, dilationAttr)
+                   expandOutputTensor.getResult(), stridesAttr, dilationAttr)
               .getResult(0);

-    SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
-    for (auto dim = 3; dim <= (int64_t)inRank; dim++)
-      indices.push_back({dim});
     conv = rewriter.create<tensor::CollapseShapeOp>(
-        loc, outputTensor.getType(), conv, indices);
+        loc, outputTensor.getType(), conv,
+        expandOutputTensor.getReassociationIndices());
   }

   Type newResultType = getTypeConverter()->convertType(op.getType());
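One detail worth noting in the hunk above: the hand-rolled `indices` vector is dropped because the reassociation used to expand the output tensor is exactly the grouping needed to collapse the convolution result back. A small numpy sketch (not from the patch; shapes assumed) of that round trip:

# Sketch: the index groups {0}, {1,2}, {3}, {4} that expand
# N,F,H,W -> N,G,F/G,H,W describe the same dimension grouping needed
# to collapse the conv result back to N,F,H,W, so one reassociation
# serves both tensor.expand_shape and tensor.collapse_shape.
import numpy as np

N, G, F, H, W = 2, 3, 6, 4, 4
out = np.zeros((N, F, H, W))

expanded = out.reshape(N, G, F // G, H, W)   # expand with {0},{1,2},{3},{4}
collapsed = expanded.reshape(N, F, H, W)     # collapse with the same groups
assert collapsed.shape == (N, F, H, W)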
@@ -23,14 +23,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     "IscloseStaticModuleTrue_basic"
 }

-if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
-    LINALG_XFAIL_SET |= {
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped",
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
-        "ConvolutionModule2DGroups_basic",
-    }
-
-
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
@@ -316,13 +308,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "ArangeStartOutViewModule_basic",
 }

-if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
-    TORCHDYNAMO_XFAIL_SET |= {
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped",
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
-        "ConvolutionModule2DGroups_basic",
-    }
-
 TORCHDYNAMO_CRASHING_SET = {
     # No upstream decompositions.
     # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)