From 2ea2bc39489cd849e2d606b48be324da2e62f7e7 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 14 Jun 2024 21:48:53 +0530
Subject: [PATCH] [ONNX] Add OnnxToTorch Lowering for GroupNormalization op
 (#3458)

Signed-Off By: Vivek Khandelwal
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 72 ++++++++++++++++++-
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 25 +++++++
 2 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 87afc46bd..d3cffd89c 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1818,6 +1818,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
             binder.s64IntegerAttr(stashType, "stash_type", 1))
           return failure();
+
+        // The torch op has no equivalent of the `stash_type` arg, so we
+        // only support the case where `stash_type` matches the input
+        // dtype; that case requires no input type conversion and can
+        // therefore be lowered directly.
+        auto xType = cast<Torch::ValueTensorType>(x.getType());
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        if (*stashDtype != xType.getOptionalDtype())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type should be same "
+                         "as the input dtype");
+
         Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(epsilon));
@@ -1826,7 +1848,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           rank = *maybeRank;
         SmallVector<Value> normalized;
         axis = Torch::toPositiveDim(axis, rank);
-        auto xType = cast<Torch::ValueTensorType>(x.getType());
         if (!xType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected input (X) to have sizes");
@@ -2444,4 +2465,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               paddingList);
           return success();
         });
+  patterns.onOp(
+      "GroupNormalization", 18,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input, scale, bias;
+        int64_t numGroups, stashType;
+        float epsilon;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.tensorOperandAtIndex(scale, 1) ||
+            binder.tensorOperandAtIndex(bias, 2) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(numGroups, "num_groups") ||
+            binder.f32FloatAttr(epsilon, "epsilon", 1e-5) ||
+            binder.s64IntegerAttr(stashType, "stash_type", 1))
+          return failure();
+
+        // The torch op has no equivalent of the `stash_type` arg, so we
+        // only support the case where `stash_type` matches the input
+        // dtype; that case requires no input type conversion and can
+        // therefore be lowered directly.
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        auto inputDtype =
+            cast<Torch::ValueTensorType>(input.getType()).getOptionalDtype();
+        if (*stashDtype != inputDtype)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type != input dtype");
+
+        Value cstEpsilon = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr((double)epsilon));
+        Value cstNumGroups = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(numGroups));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(false));
+        rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
+            binder.op, resultType, input, cstNumGroups, scale, bias,
+            cstEpsilon, /*cudnn_enabled=*/cstFalse);
+        return success();
+      });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index fc79f88b1..72af7eca9 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1292,3 +1292,28 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_group_normalization
+func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %[[INT2]], %arg1, %arg2, %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  return %0 : !torch.vtensor<[3,4,2,2],f32>
+}
+
+// -----
+
+func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %[[INT2]], %arg1, %arg2, %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  return %0 : !torch.vtensor<[3,4,2,2],f32>
+}