From fcc5f444cd81c780b1dbea3ffb39a54e83bebfe1 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 22 Aug 2024 21:20:40 +0530 Subject: [PATCH] MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658) This commit adds the shape info for the tensors created during the decomposition of GroupNorm op. Signed-Off By: Vivek Khandelwal --- .../Torch/Transforms/DecomposeComplexOps.cpp | 79 ++++++++++++++----- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 12 +-- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a4eb6dcff..af90280d7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -6233,7 +6233,6 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenGroupNormOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); Value input = op.getInput(); Value weight = op.getWeight(); @@ -6241,11 +6240,23 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { Value numGroups = op.getNumGroups(); Value eps = op.getEps(); + int64_t numGroupsInt; + if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: num_groups must be a constant int"); + Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + auto inputType = cast(input.getType()); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure(op, "input should have sizes."); + + SmallVector baseTypeSizes{inputType.getSizes()[0], numGroupsInt}; + auto baseType = inputType.getWithSizesAndDtype( + baseTypeSizes, inputType.getOptionalDtype()); Value N = rewriter.create(loc, input, cstZero); Value C = rewriter.create(loc, input, cstOne); @@ -6299,7 +6310,6 @@ class DecomposeAtenNativeGroupNormOp rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); Value cstTrue = rewriter.create(loc, true); Value cstFalse = rewriter.create(loc, false); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); // GroupNorm requires the channel dimension (C) to be exactly divisible by // the number of groups. @@ -6313,12 +6323,34 @@ class DecomposeAtenNativeGroupNormOp "the number of groups")); // Reshape the input tensor to (N, numGroups, -1) to apply normalization. + int64_t numGroupsInt; + if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: num_groups must be a constant int"); + SmallVector newShape; + SmallVector inputShapeInt{inputType.getSizes()}; + SmallVector reshapeInputShape{inputShapeInt[0], numGroupsInt}; + int64_t reshapeInputLastDim = 1; + for (size_t i = 1; i < inputShapeInt.size(); i++) { + if (inputShapeInt[i] == Torch::kUnknownSize) { + reshapeInputLastDim = Torch::kUnknownSize; + break; + } + reshapeInputLastDim *= inputShapeInt[i]; + } + reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize + ? reshapeInputLastDim + : reshapeInputLastDim / numGroupsInt; + reshapeInputShape.push_back(reshapeInputLastDim); + newShape.push_back(rewriter.create(loc, input, cstZero)); newShape.push_back(numGroups); newShape.push_back(cstNegtiveOne); + Type reshapeInputType = inputType.getWithSizesAndDtype( + reshapeInputShape, inputType.getOptionalDtype()); Value reshapedInput = rewriter.create( - loc, baseType, input, + loc, reshapeInputType, input, rewriter.create( loc, Torch::ListType::get(IntType::get(context)), newShape)); @@ -6327,21 +6359,28 @@ class DecomposeAtenNativeGroupNormOp Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), ArrayRef{cstNegtiveOne}); - auto mean = rewriter.create( - loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue, - /*dtype=*/none); - auto var = rewriter.create( - loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse, - /*keepdim=*/cstTrue); + + reshapeInputShape[2] = 1; + Type reductionType = inputType.getWithSizesAndDtype( + reshapeInputShape, inputType.getOptionalDtype()); + auto mean = + rewriter.create(loc, reductionType, reshapedInput, + /*dims=*/dimList, /*keepdim=*/cstTrue, + /*dtype=*/none); + auto var = + rewriter.create(loc, reductionType, reshapedInput, + /*dims=*/dimList, /*unbiased=*/cstFalse, + /*keepdim=*/cstTrue); // Compute the normalized output: (input - mean) * rsqrt(var + eps) - auto varPlusEps = rewriter.create(loc, baseType, var, eps, - /*alpha=*/cstOne); - auto invStd = rewriter.create(loc, baseType, varPlusEps); + auto varPlusEps = + rewriter.create(loc, reductionType, var, eps, + /*alpha=*/cstOne); + auto invStd = rewriter.create(loc, reductionType, varPlusEps); auto inputSubMean = rewriter.create( - loc, baseType, reshapedInput, mean, /*alpha=*/cstOne); - auto normalizedOutput = - rewriter.create(loc, baseType, inputSubMean, invStd); + loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne); + auto normalizedOutput = rewriter.create( + loc, reshapeInputType, inputSubMean, invStd); // Reshape normalized output back to the original input shape auto inputShape = rewriter.create( @@ -6352,22 +6391,26 @@ class DecomposeAtenNativeGroupNormOp // Apply weight and bias if they are not None // Reshape weight and bias to C,1,1,... SmallVector viewShape = {channel}; + SmallVector viewShapeInt{inputShapeInt[1]}; for (unsigned i = 2; i < inputType.getSizes().size(); i++) { viewShape.push_back(cstOne); + viewShapeInt.push_back(1); } Value viewShapeSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), viewShape); + Type viewType = inputType.getWithSizesAndDtype( + viewShapeInt, inputType.getOptionalDtype()); Value groupNormOutput = reshapedOutput; if (!isa(weight.getType())) { auto weightReshaped = rewriter.create( - loc, baseType, weight, /*shape=*/viewShapeSizeList); + loc, viewType, weight, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( loc, inputType, groupNormOutput, weightReshaped); } if (!isa(bias.getType())) { auto biasReshaped = rewriter.create( - loc, baseType, bias, /*shape=*/viewShapeSizeList); + loc, viewType, bias, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( loc, inputType, groupNormOutput, biasReshaped, /*alpha=*/cstOne); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index eaaff8d26..f291a5991 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1626,25 +1626,25 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1 // ----- // CHECK-LABEL: func.func @test_group_normalization -func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6 // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> - %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> + %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> return %0 : !torch.vtensor<[3,4,2,2],f32> } // ----- -func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821 // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> - %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> + %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> return %0 : !torch.vtensor<[3,4,2,2],f32> }