MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658)

This commit adds the shape info for the tensors created during the
decomposition of GroupNorm op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3636/merge
Vivek Khandelwal 2024-08-22 21:20:40 +05:30 committed by GitHub
parent a980130676
commit fcc5f444cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 67 additions and 24 deletions

View File

@ -6233,7 +6233,6 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
LogicalResult matchAndRewrite(AtenGroupNormOp op, LogicalResult matchAndRewrite(AtenGroupNormOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.getInput(); Value input = op.getInput();
Value weight = op.getWeight(); Value weight = op.getWeight();
@ -6241,11 +6240,23 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
Value numGroups = op.getNumGroups(); Value numGroups = op.getNumGroups();
Value eps = op.getEps(); Value eps = op.getEps();
int64_t numGroupsInt;
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: num_groups must be a constant int");
Value cstZero = Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne = Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes())
return rewriter.notifyMatchFailure(op, "input should have sizes.");
SmallVector<int64_t> baseTypeSizes{inputType.getSizes()[0], numGroupsInt};
auto baseType = inputType.getWithSizesAndDtype(
baseTypeSizes, inputType.getOptionalDtype());
Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero); Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne); Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
@ -6299,7 +6310,6 @@ class DecomposeAtenNativeGroupNormOp
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true); Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false); Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
// GroupNorm requires the channel dimension (C) to be exactly divisible by // GroupNorm requires the channel dimension (C) to be exactly divisible by
// the number of groups. // the number of groups.
@ -6313,12 +6323,34 @@ class DecomposeAtenNativeGroupNormOp
"the number of groups")); "the number of groups"));
// Reshape the input tensor to (N, numGroups, -1) to apply normalization. // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
int64_t numGroupsInt;
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: num_groups must be a constant int");
SmallVector<Value> newShape; SmallVector<Value> newShape;
SmallVector<int64_t> inputShapeInt{inputType.getSizes()};
SmallVector<int64_t> reshapeInputShape{inputShapeInt[0], numGroupsInt};
int64_t reshapeInputLastDim = 1;
for (size_t i = 1; i < inputShapeInt.size(); i++) {
if (inputShapeInt[i] == Torch::kUnknownSize) {
reshapeInputLastDim = Torch::kUnknownSize;
break;
}
reshapeInputLastDim *= inputShapeInt[i];
}
reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize
? reshapeInputLastDim
: reshapeInputLastDim / numGroupsInt;
reshapeInputShape.push_back(reshapeInputLastDim);
newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero)); newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
newShape.push_back(numGroups); newShape.push_back(numGroups);
newShape.push_back(cstNegtiveOne); newShape.push_back(cstNegtiveOne);
Type reshapeInputType = inputType.getWithSizesAndDtype(
reshapeInputShape, inputType.getOptionalDtype());
Value reshapedInput = rewriter.create<AtenViewOp>( Value reshapedInput = rewriter.create<AtenViewOp>(
loc, baseType, input, loc, reshapeInputType, input,
rewriter.create<PrimListConstructOp>( rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(IntType::get(context)), newShape)); loc, Torch::ListType::get(IntType::get(context)), newShape));
@ -6327,21 +6359,28 @@ class DecomposeAtenNativeGroupNormOp
Value dimList = rewriter.create<PrimListConstructOp>( Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
ArrayRef<Value>{cstNegtiveOne}); ArrayRef<Value>{cstNegtiveOne});
auto mean = rewriter.create<AtenMeanDimOp>(
loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue, reshapeInputShape[2] = 1;
/*dtype=*/none); Type reductionType = inputType.getWithSizesAndDtype(
auto var = rewriter.create<AtenVarDimOp>( reshapeInputShape, inputType.getOptionalDtype());
loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse, auto mean =
/*keepdim=*/cstTrue); rewriter.create<AtenMeanDimOp>(loc, reductionType, reshapedInput,
/*dims=*/dimList, /*keepdim=*/cstTrue,
/*dtype=*/none);
auto var =
rewriter.create<AtenVarDimOp>(loc, reductionType, reshapedInput,
/*dims=*/dimList, /*unbiased=*/cstFalse,
/*keepdim=*/cstTrue);
// Compute the normalized output: (input - mean) * rsqrt(var + eps) // Compute the normalized output: (input - mean) * rsqrt(var + eps)
auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps, auto varPlusEps =
/*alpha=*/cstOne); rewriter.create<AtenAddScalarOp>(loc, reductionType, var, eps,
auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps); /*alpha=*/cstOne);
auto invStd = rewriter.create<AtenRsqrtOp>(loc, reductionType, varPlusEps);
auto inputSubMean = rewriter.create<AtenSubTensorOp>( auto inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, baseType, reshapedInput, mean, /*alpha=*/cstOne); loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne);
auto normalizedOutput = auto normalizedOutput = rewriter.create<AtenMulTensorOp>(
rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd); loc, reshapeInputType, inputSubMean, invStd);
// Reshape normalized output back to the original input shape // Reshape normalized output back to the original input shape
auto inputShape = rewriter.create<AtenSizeOp>( auto inputShape = rewriter.create<AtenSizeOp>(
@ -6352,22 +6391,26 @@ class DecomposeAtenNativeGroupNormOp
// Apply weight and bias if they are not None // Apply weight and bias if they are not None
// Reshape weight and bias to C,1,1,... // Reshape weight and bias to C,1,1,...
SmallVector<Value> viewShape = {channel}; SmallVector<Value> viewShape = {channel};
SmallVector<int64_t> viewShapeInt{inputShapeInt[1]};
for (unsigned i = 2; i < inputType.getSizes().size(); i++) { for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
viewShape.push_back(cstOne); viewShape.push_back(cstOne);
viewShapeInt.push_back(1);
} }
Value viewShapeSizeList = rewriter.create<PrimListConstructOp>( Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), viewShape); loc, ListType::get(IntType::get(context)), viewShape);
Type viewType = inputType.getWithSizesAndDtype(
viewShapeInt, inputType.getOptionalDtype());
Value groupNormOutput = reshapedOutput; Value groupNormOutput = reshapedOutput;
if (!isa<Torch::NoneType>(weight.getType())) { if (!isa<Torch::NoneType>(weight.getType())) {
auto weightReshaped = rewriter.create<AtenViewOp>( auto weightReshaped = rewriter.create<AtenViewOp>(
loc, baseType, weight, /*shape=*/viewShapeSizeList); loc, viewType, weight, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenMulTensorOp>( groupNormOutput = rewriter.create<AtenMulTensorOp>(
loc, inputType, groupNormOutput, weightReshaped); loc, inputType, groupNormOutput, weightReshaped);
} }
if (!isa<Torch::NoneType>(bias.getType())) { if (!isa<Torch::NoneType>(bias.getType())) {
auto biasReshaped = rewriter.create<AtenViewOp>( auto biasReshaped = rewriter.create<AtenViewOp>(
loc, baseType, bias, /*shape=*/viewShapeSizeList); loc, viewType, bias, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenAddTensorOp>( groupNormOutput = rewriter.create<AtenAddTensorOp>(
loc, inputType, groupNormOutput, biasReshaped, loc, inputType, groupNormOutput, biasReshaped,
/*alpha=*/cstOne); /*alpha=*/cstOne);

View File

@ -1626,25 +1626,25 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
// ----- // -----
// CHECK-LABEL: func.func @test_group_normalization // CHECK-LABEL: func.func @test_group_normalization
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6 // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32> return %0 : !torch.vtensor<[3,4,2,2],f32>
} }
// ----- // -----
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821 // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32> return %0 : !torch.vtensor<[3,4,2,2],f32>
} }