mirror of https://github.com/llvm/torch-mlir
MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658)
This commit adds the shape info for the tensors created during the decomposition of GroupNorm op. Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3636/merge
parent
a980130676
commit
fcc5f444cd
|
@ -6233,7 +6233,6 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
|
|||
LogicalResult matchAndRewrite(AtenGroupNormOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
Value input = op.getInput();
|
||||
Value weight = op.getWeight();
|
||||
|
@ -6241,11 +6240,23 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
|
|||
Value numGroups = op.getNumGroups();
|
||||
Value eps = op.getEps();
|
||||
|
||||
int64_t numGroupsInt;
|
||||
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: num_groups must be a constant int");
|
||||
|
||||
Value cstZero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "input should have sizes.");
|
||||
|
||||
SmallVector<int64_t> baseTypeSizes{inputType.getSizes()[0], numGroupsInt};
|
||||
auto baseType = inputType.getWithSizesAndDtype(
|
||||
baseTypeSizes, inputType.getOptionalDtype());
|
||||
|
||||
Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
||||
Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
|
||||
|
@ -6299,7 +6310,6 @@ class DecomposeAtenNativeGroupNormOp
|
|||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
|
||||
// GroupNorm requires the channel dimension (C) to be exactly divisible by
|
||||
// the number of groups.
|
||||
|
@ -6313,12 +6323,34 @@ class DecomposeAtenNativeGroupNormOp
|
|||
"the number of groups"));
|
||||
|
||||
// Reshape the input tensor to (N, numGroups, -1) to apply normalization.
|
||||
int64_t numGroupsInt;
|
||||
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: num_groups must be a constant int");
|
||||
|
||||
SmallVector<Value> newShape;
|
||||
SmallVector<int64_t> inputShapeInt{inputType.getSizes()};
|
||||
SmallVector<int64_t> reshapeInputShape{inputShapeInt[0], numGroupsInt};
|
||||
int64_t reshapeInputLastDim = 1;
|
||||
for (size_t i = 1; i < inputShapeInt.size(); i++) {
|
||||
if (inputShapeInt[i] == Torch::kUnknownSize) {
|
||||
reshapeInputLastDim = Torch::kUnknownSize;
|
||||
break;
|
||||
}
|
||||
reshapeInputLastDim *= inputShapeInt[i];
|
||||
}
|
||||
reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize
|
||||
? reshapeInputLastDim
|
||||
: reshapeInputLastDim / numGroupsInt;
|
||||
reshapeInputShape.push_back(reshapeInputLastDim);
|
||||
|
||||
newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
|
||||
newShape.push_back(numGroups);
|
||||
newShape.push_back(cstNegtiveOne);
|
||||
Type reshapeInputType = inputType.getWithSizesAndDtype(
|
||||
reshapeInputShape, inputType.getOptionalDtype());
|
||||
Value reshapedInput = rewriter.create<AtenViewOp>(
|
||||
loc, baseType, input,
|
||||
loc, reshapeInputType, input,
|
||||
rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(IntType::get(context)), newShape));
|
||||
|
||||
|
@ -6327,21 +6359,28 @@ class DecomposeAtenNativeGroupNormOp
|
|||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
||||
ArrayRef<Value>{cstNegtiveOne});
|
||||
auto mean = rewriter.create<AtenMeanDimOp>(
|
||||
loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
|
||||
/*dtype=*/none);
|
||||
auto var = rewriter.create<AtenVarDimOp>(
|
||||
loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
|
||||
/*keepdim=*/cstTrue);
|
||||
|
||||
reshapeInputShape[2] = 1;
|
||||
Type reductionType = inputType.getWithSizesAndDtype(
|
||||
reshapeInputShape, inputType.getOptionalDtype());
|
||||
auto mean =
|
||||
rewriter.create<AtenMeanDimOp>(loc, reductionType, reshapedInput,
|
||||
/*dims=*/dimList, /*keepdim=*/cstTrue,
|
||||
/*dtype=*/none);
|
||||
auto var =
|
||||
rewriter.create<AtenVarDimOp>(loc, reductionType, reshapedInput,
|
||||
/*dims=*/dimList, /*unbiased=*/cstFalse,
|
||||
/*keepdim=*/cstTrue);
|
||||
|
||||
// Compute the normalized output: (input - mean) * rsqrt(var + eps)
|
||||
auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
|
||||
/*alpha=*/cstOne);
|
||||
auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
|
||||
auto varPlusEps =
|
||||
rewriter.create<AtenAddScalarOp>(loc, reductionType, var, eps,
|
||||
/*alpha=*/cstOne);
|
||||
auto invStd = rewriter.create<AtenRsqrtOp>(loc, reductionType, varPlusEps);
|
||||
auto inputSubMean = rewriter.create<AtenSubTensorOp>(
|
||||
loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
|
||||
auto normalizedOutput =
|
||||
rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
|
||||
loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne);
|
||||
auto normalizedOutput = rewriter.create<AtenMulTensorOp>(
|
||||
loc, reshapeInputType, inputSubMean, invStd);
|
||||
|
||||
// Reshape normalized output back to the original input shape
|
||||
auto inputShape = rewriter.create<AtenSizeOp>(
|
||||
|
@ -6352,22 +6391,26 @@ class DecomposeAtenNativeGroupNormOp
|
|||
// Apply weight and bias if they are not None
|
||||
// Reshape weight and bias to C,1,1,...
|
||||
SmallVector<Value> viewShape = {channel};
|
||||
SmallVector<int64_t> viewShapeInt{inputShapeInt[1]};
|
||||
for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
|
||||
viewShape.push_back(cstOne);
|
||||
viewShapeInt.push_back(1);
|
||||
}
|
||||
Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), viewShape);
|
||||
|
||||
Type viewType = inputType.getWithSizesAndDtype(
|
||||
viewShapeInt, inputType.getOptionalDtype());
|
||||
Value groupNormOutput = reshapedOutput;
|
||||
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||
auto weightReshaped = rewriter.create<AtenViewOp>(
|
||||
loc, baseType, weight, /*shape=*/viewShapeSizeList);
|
||||
loc, viewType, weight, /*shape=*/viewShapeSizeList);
|
||||
groupNormOutput = rewriter.create<AtenMulTensorOp>(
|
||||
loc, inputType, groupNormOutput, weightReshaped);
|
||||
}
|
||||
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||
auto biasReshaped = rewriter.create<AtenViewOp>(
|
||||
loc, baseType, bias, /*shape=*/viewShapeSizeList);
|
||||
loc, viewType, bias, /*shape=*/viewShapeSizeList);
|
||||
groupNormOutput = rewriter.create<AtenAddTensorOp>(
|
||||
loc, inputType, groupNormOutput, biasReshaped,
|
||||
/*alpha=*/cstOne);
|
||||
|
|
|
@ -1626,25 +1626,25 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_group_normalization
|
||||
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue