mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658)

This commit adds the shape info for the tensors created during the decomposition of the GroupNorm op.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

parent a980130676
commit fcc5f444cd
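Background for the diff below: the old decomposition typed every intermediate tensor with ValueTensorType::getWithLeastStaticInformation, i.e. with no shape information; the patch derives static sizes from the input type instead. A minimal standalone C++ sketch of the shape arithmetic being added (illustrative only; -1 stands in for Torch::kUnknownSize and the helper name is made up, not the pass's API):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknown = -1; // stand-in for Torch::kUnknownSize

// Shape of the (N, G, -1) view that normalization is applied to.
std::vector<int64_t> reshapeShape(const std::vector<int64_t> &in, int64_t groups) {
  std::vector<int64_t> out{in[0], groups};
  int64_t last = 1;
  for (std::size_t i = 1; i < in.size(); ++i) {
    if (in[i] == kUnknown) { last = kUnknown; break; } // any dynamic dim poisons the product
    last *= in[i];
  }
  out.push_back(last == kUnknown ? last : last / groups);
  return out;
}

int main() {
  // Worked example matching the test below: input [3,4,2,2], num_groups = 2.
  std::vector<int64_t> input{3, 4, 2, 2};
  std::vector<int64_t> grouped = reshapeShape(input, 2); // -> [3, 2, 8]
  std::vector<int64_t> reduced = grouped;                // mean/var keep dims
  reduced[2] = 1;                                        // -> [3, 2, 1]
  std::vector<int64_t> view{input[1]};                   // weight/bias view: [4, 1, 1]
  for (std::size_t i = 2; i < input.size(); ++i)
    view.push_back(1);
  assert(grouped[2] == 8 && reduced[2] == 1 && view.size() == 3);
  std::cout << "grouped last dim = " << grouped[2] << "\n";
  return 0;
}

For the test input [3,4,2,2] with two groups this gives a [3,2,8] view for normalization, a [3,2,1] mean/variance type, and a [4,1,1] weight/bias view, which is exactly the information the patch attaches to the ops it creates.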
@@ -6233,7 +6233,6 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
   LogicalResult matchAndRewrite(AtenGroupNormOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    MLIRContext *context = op.getContext();

     Value input = op.getInput();
     Value weight = op.getWeight();
@@ -6241,11 +6240,23 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
     Value numGroups = op.getNumGroups();
     Value eps = op.getEps();

+    int64_t numGroupsInt;
+    if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: num_groups must be a constant int");
+
     Value cstZero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value cstOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
-    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    if (!inputType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input should have sizes.");
+
+    SmallVector<int64_t> baseTypeSizes{inputType.getSizes()[0], numGroupsInt};
+    auto baseType = inputType.getWithSizesAndDtype(
+        baseTypeSizes, inputType.getOptionalDtype());

     Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
     Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
@@ -6299,7 +6310,6 @@ class DecomposeAtenNativeGroupNormOp
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
     Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

     // GroupNorm requires the channel dimension (C) to be exactly divisible by
     // the number of groups.
@@ -6313,12 +6323,34 @@ class DecomposeAtenNativeGroupNormOp
                                    "the number of groups"));

     // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
+    int64_t numGroupsInt;
+    if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: num_groups must be a constant int");
+
     SmallVector<Value> newShape;
+    SmallVector<int64_t> inputShapeInt{inputType.getSizes()};
+    SmallVector<int64_t> reshapeInputShape{inputShapeInt[0], numGroupsInt};
+    int64_t reshapeInputLastDim = 1;
+    for (size_t i = 1; i < inputShapeInt.size(); i++) {
+      if (inputShapeInt[i] == Torch::kUnknownSize) {
+        reshapeInputLastDim = Torch::kUnknownSize;
+        break;
+      }
+      reshapeInputLastDim *= inputShapeInt[i];
+    }
+    reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize
+                              ? reshapeInputLastDim
+                              : reshapeInputLastDim / numGroupsInt;
+    reshapeInputShape.push_back(reshapeInputLastDim);
+
     newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
     newShape.push_back(numGroups);
     newShape.push_back(cstNegtiveOne);
+    Type reshapeInputType = inputType.getWithSizesAndDtype(
+        reshapeInputShape, inputType.getOptionalDtype());
     Value reshapedInput = rewriter.create<AtenViewOp>(
-        loc, baseType, input,
+        loc, reshapeInputType, input,
         rewriter.create<PrimListConstructOp>(
             loc, Torch::ListType::get(IntType::get(context)), newShape));

@@ -6327,21 +6359,28 @@ class DecomposeAtenNativeGroupNormOp
     Value dimList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
         ArrayRef<Value>{cstNegtiveOne});
-    auto mean = rewriter.create<AtenMeanDimOp>(
-        loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
-        /*dtype=*/none);
-    auto var = rewriter.create<AtenVarDimOp>(
-        loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
-        /*keepdim=*/cstTrue);
+
+    reshapeInputShape[2] = 1;
+    Type reductionType = inputType.getWithSizesAndDtype(
+        reshapeInputShape, inputType.getOptionalDtype());
+    auto mean =
+        rewriter.create<AtenMeanDimOp>(loc, reductionType, reshapedInput,
+                                       /*dims=*/dimList, /*keepdim=*/cstTrue,
+                                       /*dtype=*/none);
+    auto var =
+        rewriter.create<AtenVarDimOp>(loc, reductionType, reshapedInput,
+                                      /*dims=*/dimList, /*unbiased=*/cstFalse,
+                                      /*keepdim=*/cstTrue);

     // Compute the normalized output: (input - mean) * rsqrt(var + eps)
-    auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
-                                                       /*alpha=*/cstOne);
-    auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
+    auto varPlusEps =
+        rewriter.create<AtenAddScalarOp>(loc, reductionType, var, eps,
+                                         /*alpha=*/cstOne);
+    auto invStd = rewriter.create<AtenRsqrtOp>(loc, reductionType, varPlusEps);
     auto inputSubMean = rewriter.create<AtenSubTensorOp>(
-        loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
-    auto normalizedOutput =
-        rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
+        loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne);
+    auto normalizedOutput = rewriter.create<AtenMulTensorOp>(
+        loc, reshapeInputType, inputSubMean, invStd);

     // Reshape normalized output back to the original input shape
     auto inputShape = rewriter.create<AtenSizeOp>(
@@ -6352,22 +6391,26 @@ class DecomposeAtenNativeGroupNormOp
     // Apply weight and bias if they are not None
     // Reshape weight and bias to C,1,1,...
     SmallVector<Value> viewShape = {channel};
+    SmallVector<int64_t> viewShapeInt{inputShapeInt[1]};
     for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
       viewShape.push_back(cstOne);
+      viewShapeInt.push_back(1);
     }
     Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(IntType::get(context)), viewShape);

+    Type viewType = inputType.getWithSizesAndDtype(
+        viewShapeInt, inputType.getOptionalDtype());
     Value groupNormOutput = reshapedOutput;
     if (!isa<Torch::NoneType>(weight.getType())) {
       auto weightReshaped = rewriter.create<AtenViewOp>(
-          loc, baseType, weight, /*shape=*/viewShapeSizeList);
+          loc, viewType, weight, /*shape=*/viewShapeSizeList);
       groupNormOutput = rewriter.create<AtenMulTensorOp>(
           loc, inputType, groupNormOutput, weightReshaped);
     }
     if (!isa<Torch::NoneType>(bias.getType())) {
       auto biasReshaped = rewriter.create<AtenViewOp>(
-          loc, baseType, bias, /*shape=*/viewShapeSizeList);
+          loc, viewType, bias, /*shape=*/viewShapeSizeList);
       groupNormOutput = rewriter.create<AtenAddTensorOp>(
           loc, inputType, groupNormOutput, biasReshaped,
           /*alpha=*/cstOne);
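For reference, the decomposition above computes per-(sample, group) statistics over the flattened group, then (x - mean) * rsqrt(var + eps), and finally applies the per-channel weight and bias, mirroring the AtenMeanDimOp / AtenVarDimOp (unbiased = false) / AtenRsqrtOp / AtenSubTensorOp / AtenMulTensorOp sequence in the diff. A self-contained C++ sketch of that arithmetic on a plain NCHW buffer (illustrative only, not the pass itself; names are made up):

#include <cmath>
#include <cstdio>
#include <vector>

// Group normalization over a contiguous NCHW buffer.
// weight/bias have C elements (per channel), matching aten.group_norm.
std::vector<float> groupNorm(const std::vector<float> &x, int N, int C, int HW,
                             int groups, const std::vector<float> &weight,
                             const std::vector<float> &bias, float eps) {
  std::vector<float> y(x.size());
  int chansPerGroup = C / groups;
  int groupSize = chansPerGroup * HW; // elements reduced per (n, g)
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < groups; ++g) {
      const float *in = x.data() + (n * C + g * chansPerGroup) * HW;
      double mean = 0.0, var = 0.0;
      for (int i = 0; i < groupSize; ++i)
        mean += in[i];
      mean /= groupSize;
      for (int i = 0; i < groupSize; ++i)
        var += (in[i] - mean) * (in[i] - mean);
      var /= groupSize; // biased variance, matching /*unbiased=*/cstFalse
      float invStd = 1.0f / std::sqrt(static_cast<float>(var) + eps);
      for (int c = 0; c < chansPerGroup; ++c) {
        int ch = g * chansPerGroup + c;
        for (int i = 0; i < HW; ++i) {
          float normed = (in[c * HW + i] - static_cast<float>(mean)) * invStd;
          y[(n * C + ch) * HW + i] = normed * weight[ch] + bias[ch];
        }
      }
    }
  }
  return y;
}

int main() {
  // Shapes from the test below: input [3,4,2,2], num_groups = 2, weight/bias of size C = 4.
  std::vector<float> x(3 * 4 * 2 * 2, 1.0f);
  std::vector<float> w(4, 1.0f), b(4, 0.0f);
  auto y = groupNorm(x, /*N=*/3, /*C=*/4, /*HW=*/4, /*groups=*/2, w, b, 1e-5f);
  std::printf("y[0] = %f\n", y[0]); // constant input normalizes to ~0
  return 0;
}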
@@ -1626,25 +1626,25 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
 // -----

 // CHECK-LABEL: func.func @test_group_normalization
-func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
   // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
-  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
   return %0 : !torch.vtensor<[3,4,2,2],f32>
 }

 // -----

-func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
   // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
-  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
   return %0 : !torch.vtensor<[3,4,2,2],f32>
 }
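The test change above switches the GroupNormalization scale and bias operands from [2] to [4]: torch.aten.group_norm expects per-channel affine parameters, so for C = 4 channels split into 2 groups the weight and bias must each have 4 elements, not 2. A tiny sanity check of that constraint (illustrative only, not part of the lowering):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> inputShape{3, 4, 2, 2}; // from the test case
  const int64_t numGroups = 2;
  const int64_t C = inputShape[1];
  const std::vector<int64_t> weightShape{4}, biasShape{4};

  assert(C % numGroups == 0);  // channels divide evenly into groups
  assert(weightShape[0] == C); // affine params are per channel ...
  assert(biasShape[0] == C);   // ... hence [4], not [2], in the test
  return 0;
}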