From dcb48dd46ccd10b6cccd6396a5e495d21a9c4d52 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:42:26 -0700 Subject: [PATCH] [ONNX] Fix LpNormalization Lowering (#3521) The LpNormalization lowering was previously just computing the norm, which is incorrect. This computes the norm then divides the input tensor by it's norm. I've tested this against some simple onnx models locally. I'll look into adding a test case for this in an external test suite. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 63 +++++++++++-------- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 11 ++-- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0415a562d..2b1bec3f9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -2674,36 +2674,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); - patterns.onOp( - "LpNormalization", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - int64_t axis, p; - Value input; - if (binder.tensorOperand(input) || - binder.s64IntegerAttr(axis, "axis", -1) || - binder.s64IntegerAttr(p, "p", 2) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp("LpNormalization", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t axis, p; + Value input; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.s64IntegerAttr(p, "p", 2) || + binder.tensorResultType(resultType)) + return failure(); - auto loc = binder.getLoc(); - Value cstAxis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axis)); - Value cstP = rewriter.create( - loc, rewriter.getI64IntegerAttr(p)); - Value cstKeepDim = rewriter.create( - loc, rewriter.getBoolAttr(true)); - Value axisPrimList = rewriter.create( - binder.getLoc(), - rewriter.getType( - rewriter.getType()), - llvm::ArrayRef{cstAxis}); + auto loc = binder.getLoc(); + Value cstAxis = rewriter.create( + loc, rewriter.getI64IntegerAttr(axis)); + Value cstP = rewriter.create( + loc, rewriter.getI64IntegerAttr(p)); + Value cstKeepDim = rewriter.create( + loc, rewriter.getBoolAttr(true)); + Value axisPrimList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + llvm::ArrayRef{cstAxis}); - rewriter.replaceOpWithNewOp( - binder.op, resultType, input, cstP, axisPrimList, cstKeepDim); + SmallVector normSizes(resultType.getSizes()); + int64_t rank = normSizes.size(); + axis = axis % rank; + axis = (axis < 0) ? axis + rank : axis; + normSizes[axis] = 1; + auto normType = rewriter.getType( + normSizes, resultType.getDtype()); + Value norm = rewriter.create( + loc, normType, input, cstP, axisPrimList, cstKeepDim); - return success(); - }); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, norm); + return success(); + }); patterns.onOp( "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // TODO: Add support for `output_shape` arg. diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 74713552b..38f81f4c0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1423,15 +1423,16 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3 // ----- // CHECK-LABEL: @test_lpnormalization -func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[CST2:.*]] = torch.constant.int 2 // CHECK: %[[CST2_0:.*]] = torch.constant.int 2 // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list - // CHECK: %[[OUT:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32> - // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,1,6,7],f32> - %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> - return %0 : !torch.vtensor<[3,4,1,6,7],f32> + // CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32> + // CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32> + %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> + return %0 : !torch.vtensor<[3,4,5,6,7],f32> } // -----