From 5862854bc8011a94a54edeb4fa278908e9eb2c2b Mon Sep 17 00:00:00 2001
From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com>
Date: Thu, 11 Jan 2024 00:57:04 -0800
Subject: [PATCH] [ONNX][TORCH-MLIR] LayerNorm (#2716)

Lower ONNX LayerNormalization using torch.aten.native_layer_norm.
https://github.com/nod-ai/SHARK-Turbine/issues/325
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 43 +++++++++++++++++++
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 13 ++++++
 2 files changed, 56 insertions(+)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index ee1ae6bb6..fd7013afd 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -410,6 +410,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                   }
                   return failure();
                 });
+  patterns.onOp("LayerNormalization", 17,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType Y_type;
+                  Torch::ValueTensorType Mean_type;
+                  Torch::ValueTensorType InvStdDev_type;
+                  Value X;
+                  Value Scale;
+                  Value B;
+                  int64_t axis;
+                  float epsilon;
+                  int64_t stash_type;
+                  if (binder.tensorOperandAtIndex(X, 0) ||
+                      binder.tensorOperandAtIndex(Scale, 1) ||
+                      binder.tensorOperandAtIndex(B, 2) ||
+                      binder.tensorResultTypeAtIndex(Y_type, 0) ||
+                      binder.tensorResultTypeAtIndex(Mean_type, 1) ||
+                      binder.tensorResultTypeAtIndex(InvStdDev_type, 2) ||
+                      binder.s64IntegerAttr(axis, "axis", -1) ||
+                      binder.f32FloatAttr(epsilon, "epsilon", 0.00001) ||
+                      binder.s64IntegerAttr(stash_type, "stash_type", 1))
+                    return failure();
+                  Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                      rewriter.getF64FloatAttr(epsilon));
+                  unsigned rank = 1;
+                  if (std::optional<unsigned> maybeRank = Torch::getTensorRank(X))
+                    rank = *maybeRank;
+                  SmallVector<Value> normalized;
+                  axis = Torch::toPositiveDim(axis, rank);
+                  auto X_type = X.getType().cast<Torch::ValueTensorType>();
+                  ArrayRef<int64_t> X_shape = X_type.getSizes();
+                  for (int64_t n = axis; n < rank; n++) {
+                    normalized.push_back(rewriter.create<Torch::ConstantIntOp>(
+                        binder.getLoc(), rewriter.getI64IntegerAttr(X_shape[n])));
+                  }
+                  Value normalized_shape = rewriter.create<Torch::PrimListConstructOp>(
+                      binder.getLoc(),
+                      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+                      normalized);
+                  rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
+                      binder.op, Y_type, Mean_type, InvStdDev_type, X, normalized_shape, Scale, B, constEpsilon);
+                  return success();
+                });
   patterns.onOp("LeakyRelu", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index e224ddfa2..07ddf3e59 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -116,6 +116,19 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.
 
 // -----
 
+// CHECK-LABEL: func.func @test_layer_norm
+func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>)
+  attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %int3 = torch.constant.int 3
+  // CHECK: %int4 = torch.constant.int 4
+  // CHECK: %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2
+  %0:3 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>)
+  return %0#0, %0#1, %0#2 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_leaky_relu
 func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} {
   // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
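
For reference, here is a rough sketch of the IR the new pattern should produce for the @test_layer_norm case above. It is hand-written for illustration, not captured from an actual convert-torch-onnx-to-torch run; the SSA value names (%eps, %shape, %Y, %Mean, %InvStdDev) and the printed spelling of the default epsilon (1e-05) are assumptions, only the op sequence follows the pattern in this patch.

  func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
    // With axis = 0, every dimension of %arg0 is normalized, so
    // normalized_shape is built from the full shape [3, 4].
    %eps = torch.constant.float 1.000000e-05
    %int3 = torch.constant.int 3
    %int4 = torch.constant.int 4
    %shape = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %Y, %Mean, %InvStdDev = torch.aten.native_layer_norm %arg0, %shape, %arg1, %arg2, %eps : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
    return %Y, %Mean, %InvStdDev : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
  }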