[ONNX] Fix LpNormalization Lowering (#3521)

The LpNormalization lowering previously computed only the norm, which is
incorrect. This patch computes the norm and then divides the input tensor
by its norm.
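
For reference, ONNX LpNormalization(input, axis, p) divides each slice of the
input along `axis` by that slice's Lp norm. Below is a minimal standalone
sketch of those semantics for a row-major 2D tensor normalized along axis 1;
the helper `lpNormalize` and its parameters are illustrative, not part of
this patch:

// Minimal sketch of ONNX LpNormalization semantics (assumed helper, not
// torch-mlir code): normalize a row-major rows x cols tensor along axis 1,
// for p in {1, 2} as allowed by the ONNX spec.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> lpNormalize(const std::vector<float> &x, std::size_t rows,
                               std::size_t cols, int p) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < rows; ++i) {
    // Lp norm of row i; keepdim=true semantics collapse axis 1 to size 1.
    float norm = 0.0f;
    for (std::size_t j = 0; j < cols; ++j) {
      float v = x[i * cols + j];
      norm += (p == 1) ? std::fabs(v) : v * v;
    }
    if (p == 2)
      norm = std::sqrt(norm);
    // The fix in this commit: divide the input by its norm rather than
    // returning the norm itself.
    for (std::size_t j = 0; j < cols; ++j)
      out[i * cols + j] = x[i * cols + j] / norm;
  }
  return out;
}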

I've tested this against some simple ONNX models locally and will look into
adding a test case for this in an external test suite.
zjgarvey 2024-07-09 13:42:26 -07:00 committed by GitHub
parent 0b46d1110a
commit dcb48dd46c
2 changed files with 42 additions and 32 deletions


@@ -2674,8 +2674,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
-  patterns.onOp(
-      "LpNormalization", 1,
+  patterns.onOp("LpNormalization", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         int64_t axis, p;
@@ -2693,15 +2692,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             loc, rewriter.getI64IntegerAttr(p));
         Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
             loc, rewriter.getBoolAttr(true));
-        Value axisPrimList = rewriter.create<Torch::PrimListConstructOp>(
+        Value axisPrimList =
+            rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             rewriter.getType<Torch::ListType>(
                 rewriter.getType<Torch::IntType>()),
             llvm::ArrayRef<Value>{cstAxis});
-        rewriter.replaceOpWithNewOp<Torch::AtenNormScalarOptDimOp>(
-            binder.op, resultType, input, cstP, axisPrimList, cstKeepDim);
+        SmallVector<int64_t> normSizes(resultType.getSizes());
+        int64_t rank = normSizes.size();
+        axis = axis % rank;
+        axis = (axis < 0) ? axis + rank : axis;
+        normSizes[axis] = 1;
+        auto normType = rewriter.getType<Torch::ValueTensorType>(
+            normSizes, resultType.getDtype());
+        Value norm = rewriter.create<Torch::AtenNormScalarOptDimOp>(
+            loc, normType, input, cstP, axisPrimList, cstKeepDim);
+        rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
+            binder.op, resultType, input, norm);
         return success();
       });
   patterns.onOp(
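
The two added `axis` lines in the hunk above wrap a possibly negative ONNX
axis attribute into the range [0, rank). The same arithmetic as a standalone
sketch; `canonicalizeAxis` is an illustrative name, not from the patch:

#include <cassert>
#include <cstdint>

// Wrap a possibly negative axis into [0, rank), mirroring the two lines
// added in the lowering above.
int64_t canonicalizeAxis(int64_t axis, int64_t rank) {
  axis = axis % rank;                     // in C++, -1 % 4 == -1
  axis = (axis < 0) ? axis + rank : axis; // e.g. axis = -1, rank = 4 -> 3
  assert(axis >= 0 && axis < rank);
  return axis;
}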


@@ -1423,15 +1423,16 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
 // -----

 // CHECK-LABEL: @test_lpnormalization
-func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[CST2:.*]] = torch.constant.int 2
   // CHECK: %[[CST2_0:.*]] = torch.constant.int 2
   // CHECK: %[[TRUE:.*]] = torch.constant.bool true
   // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[OUT:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
-  // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,1,6,7],f32>
-  %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
-  return %0 : !torch.vtensor<[3,4,1,6,7],f32>
+  // CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
+  // CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32>
+  // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32>
+  %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32>
+  return %0 : !torch.vtensor<[3,4,5,6,7],f32>
 }

 // -----