mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where 1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W 2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W Now this has been fixed in https://github.com/llvm/llvm-project/pull/73855 which broke the torch-mlir lowering to that Op. This patch switches lowering in torch-mlir to the newly introduced `linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that is compatible with PyTorch's memory layout. Fix https://github.com/llvm/torch-mlir/issues/2622pull/2626/head snapshot-20231209.1047
parent
8252656b6d
commit
fb21a85874
|
@ -848,6 +848,7 @@ public:
|
||||||
indices);
|
indices);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// expand F,C,H,W -> G,F/G,C,H,W
|
||||||
auto expandWeight = [&](Value tensor) {
|
auto expandWeight = [&](Value tensor) {
|
||||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
auto inType = tensor.getType().cast<RankedTensorType>();
|
||||||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||||
|
@ -868,21 +869,19 @@ public:
|
||||||
|
|
||||||
Value paddedInputExpanded = expandGroups(paddedInput, 1);
|
Value paddedInputExpanded = expandGroups(paddedInput, 1);
|
||||||
Value weightExpanded = expandWeight(weight);
|
Value weightExpanded = expandWeight(weight);
|
||||||
Value outputTensorExpanded = expandGroups(outputTensor, 1);
|
auto expandOutputTensor = expandGroups(outputTensor, 1);
|
||||||
|
|
||||||
// TODO: add 1D and 3D case
|
// TODO: add 1D and 3D case
|
||||||
conv = rewriter
|
conv = rewriter
|
||||||
.create<linalg::Conv2DNgchwFgchwOp>(
|
.create<linalg::Conv2DNgchwGfchwOp>(
|
||||||
loc, outputTensorExpanded.getType(),
|
loc, expandOutputTensor.getResultType(),
|
||||||
ValueRange{paddedInputExpanded, weightExpanded},
|
ValueRange{paddedInputExpanded, weightExpanded},
|
||||||
outputTensorExpanded, stridesAttr, dilationAttr)
|
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
|
|
||||||
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
|
|
||||||
indices.push_back({dim});
|
|
||||||
conv = rewriter.create<tensor::CollapseShapeOp>(
|
conv = rewriter.create<tensor::CollapseShapeOp>(
|
||||||
loc, outputTensor.getType(), conv, indices);
|
loc, outputTensor.getType(), conv,
|
||||||
|
expandOutputTensor.getReassociationIndices());
|
||||||
}
|
}
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
|
@ -23,14 +23,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||||
"IscloseStaticModuleTrue_basic"
|
"IscloseStaticModuleTrue_basic"
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
|
|
||||||
LINALG_XFAIL_SET |= {
|
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
|
||||||
"ConvolutionModule2DGroups_basic",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TORCHDYNAMO_XFAIL_SET = {
|
TORCHDYNAMO_XFAIL_SET = {
|
||||||
#### General TorchDynamo/PyTorch errors
|
#### General TorchDynamo/PyTorch errors
|
||||||
|
|
||||||
|
@ -316,13 +308,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
|
|
||||||
TORCHDYNAMO_XFAIL_SET |= {
|
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
|
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
|
|
||||||
"ConvolutionModule2DGroups_basic",
|
|
||||||
}
|
|
||||||
|
|
||||||
TORCHDYNAMO_CRASHING_SET = {
|
TORCHDYNAMO_CRASHING_SET = {
|
||||||
# No upstream decompositions.
|
# No upstream decompositions.
|
||||||
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
||||||
|
|
Loading…
Reference in New Issue