[ONNX] Fix LpNormalization Lowering (#3521)

The LpNormalization lowering was previously just computing the norm,
which is incorrect. This change computes the norm and then divides the
input tensor by its norm.
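For reference, ONNX LpNormalization divides each slice of the input by its p-norm along `axis`, which is what the corrected lowering now emits as a keepdim norm followed by an elementwise divide. A minimal standalone C++ sketch of that semantics (illustrative only, not code from this patch; the buffer layout [outer, axisDim, inner] and all names are made up for the example):

// Illustrative reference implementation of ONNX LpNormalization semantics
// (not part of torch-mlir). The input is a flat buffer viewed as shape
// [outer, axisDim, inner].
#include <cmath>
#include <cstdio>
#include <vector>

// Divide every element by the p-norm taken along the normalization axis.
// Elements that differ only in the axis coordinate share one norm value,
// which is what the keepdim=true norm plus a broadcasted divide computes.
void lpNormalize(std::vector<float> &data, int outer, int axisDim, int inner,
                 int p) {
  for (int o = 0; o < outer; ++o) {
    for (int i = 0; i < inner; ++i) {
      double sum = 0.0;
      for (int a = 0; a < axisDim; ++a)
        sum += std::pow(std::abs(data[(o * axisDim + a) * inner + i]), p);
      double norm = std::pow(sum, 1.0 / p);
      for (int a = 0; a < axisDim; ++a)
        data[(o * axisDim + a) * inner + i] /= norm;
    }
  }
}

int main() {
  // A [1,3,1] tensor normalized along the middle axis with p = 2:
  // [3, 4, 0] has L2 norm 5, so the result is [0.6, 0.8, 0].
  std::vector<float> x = {3.0f, 4.0f, 0.0f};
  lpNormalize(x, /*outer=*/1, /*axisDim=*/3, /*inner=*/1, /*p=*/2);
  for (float v : x)
    std::printf("%g ", v);
  std::printf("\n");
}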

I've tested this against some simple onnx models locally. I'll look into
adding a test case for this in an external test suite.
zjgarvey 2024-07-09 13:42:26 -07:00 committed by GitHub
parent 0b46d1110a
commit dcb48dd46c
2 changed files with 42 additions and 32 deletions


@@ -2674,36 +2674,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
-  patterns.onOp("LpNormalization", 1,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  int64_t axis, p;
-                  Value input;
-                  if (binder.tensorOperand(input) ||
-                      binder.s64IntegerAttr(axis, "axis", -1) ||
-                      binder.s64IntegerAttr(p, "p", 2) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  auto loc = binder.getLoc();
-                  Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
-                      loc, rewriter.getI64IntegerAttr(axis));
-                  Value cstP = rewriter.create<Torch::ConstantIntOp>(
-                      loc, rewriter.getI64IntegerAttr(p));
-                  Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
-                      loc, rewriter.getBoolAttr(true));
-                  Value axisPrimList =
-                      rewriter.create<Torch::PrimListConstructOp>(
-                          binder.getLoc(),
-                          rewriter.getType<Torch::ListType>(
-                              rewriter.getType<Torch::IntType>()),
-                          llvm::ArrayRef<Value>{cstAxis});
-                  rewriter.replaceOpWithNewOp<Torch::AtenNormScalarOptDimOp>(
-                      binder.op, resultType, input, cstP, axisPrimList, cstKeepDim);
-                  return success();
-                });
+  patterns.onOp(
+      "LpNormalization", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        int64_t axis, p;
+        Value input;
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerAttr(axis, "axis", -1) ||
+            binder.s64IntegerAttr(p, "p", 2) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        auto loc = binder.getLoc();
+        Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(axis));
+        Value cstP = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(p));
+        Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+        Value axisPrimList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            llvm::ArrayRef<Value>{cstAxis});
+        SmallVector<int64_t> normSizes(resultType.getSizes());
+        int64_t rank = normSizes.size();
+        axis = axis % rank;
+        axis = (axis < 0) ? axis + rank : axis;
+        normSizes[axis] = 1;
+        auto normType = rewriter.getType<Torch::ValueTensorType>(
+            normSizes, resultType.getDtype());
+        Value norm = rewriter.create<Torch::AtenNormScalarOptDimOp>(
+            loc, normType, input, cstP, axisPrimList, cstKeepDim);
+        rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
+            binder.op, resultType, input, norm);
+        return success();
+      });
   patterns.onOp(
       "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         // TODO: Add support for `output_shape` arg.
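One detail worth noting in the added code is the axis canonicalization before normSizes[axis] = 1: ONNX's default axis is -1, and C++'s % operator keeps the dividend's sign, so the extra conditional is what maps negative axes into [0, rank). A tiny standalone sketch of that step (canonicalizeAxis is a hypothetical helper name, not torch-mlir code):

// Standalone illustration of the axis normalization used in the lowering.
#include <cassert>
#include <cstdint>

int64_t canonicalizeAxis(int64_t axis, int64_t rank) {
  axis = axis % rank;                   // C++: -1 % 5 == -1, sign follows dividend
  return axis < 0 ? axis + rank : axis; // shift negative axes into [0, rank)
}

int main() {
  assert(canonicalizeAxis(-1, 5) == 4); // ONNX default axis on a rank-5 input
  assert(canonicalizeAxis(2, 5) == 2);  // positive axes are unchanged
  assert(canonicalizeAxis(-5, 5) == 0); // most-negative valid axis maps to 0
  return 0;
}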


@@ -1423,15 +1423,16 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
 // -----
 // CHECK-LABEL: @test_lpnormalization
-func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[CST2:.*]] = torch.constant.int 2
   // CHECK: %[[CST2_0:.*]] = torch.constant.int 2
   // CHECK: %[[TRUE:.*]] = torch.constant.bool true
   // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[OUT:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
-  // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,1,6,7],f32>
-  %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
-  return %0 : !torch.vtensor<[3,4,1,6,7],f32>
+  // CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
+  // CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32>
+  // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32>
+  %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32>
+  return %0 : !torch.vtensor<[3,4,5,6,7],f32>
 }
 // -----