mirror of https://github.com/llvm/torch-mlir
[ONNX] Fix LpNormalization Lowering (#3521)
The LpNormalization lowering was previously just computing the norm, which is incorrect. This computes the norm then divides the input tensor by it's norm. I've tested this against some simple onnx models locally. I'll look into adding a test case for this in an external test suite.pull/2520/merge
parent
0b46d1110a
commit
dcb48dd46c
|
@ -2674,8 +2674,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp("LpNormalization", 1,
|
||||||
"LpNormalization", 1,
|
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
int64_t axis, p;
|
int64_t axis, p;
|
||||||
|
@ -2693,15 +2692,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
loc, rewriter.getI64IntegerAttr(p));
|
loc, rewriter.getI64IntegerAttr(p));
|
||||||
Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
|
Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
loc, rewriter.getBoolAttr(true));
|
loc, rewriter.getBoolAttr(true));
|
||||||
Value axisPrimList = rewriter.create<Torch::PrimListConstructOp>(
|
Value axisPrimList =
|
||||||
|
rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
rewriter.getType<Torch::ListType>(
|
rewriter.getType<Torch::ListType>(
|
||||||
rewriter.getType<Torch::IntType>()),
|
rewriter.getType<Torch::IntType>()),
|
||||||
llvm::ArrayRef<Value>{cstAxis});
|
llvm::ArrayRef<Value>{cstAxis});
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenNormScalarOptDimOp>(
|
SmallVector<int64_t> normSizes(resultType.getSizes());
|
||||||
binder.op, resultType, input, cstP, axisPrimList, cstKeepDim);
|
int64_t rank = normSizes.size();
|
||||||
|
axis = axis % rank;
|
||||||
|
axis = (axis < 0) ? axis + rank : axis;
|
||||||
|
normSizes[axis] = 1;
|
||||||
|
auto normType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
normSizes, resultType.getDtype());
|
||||||
|
Value norm = rewriter.create<Torch::AtenNormScalarOptDimOp>(
|
||||||
|
loc, normType, input, cstP, axisPrimList, cstKeepDim);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
|
||||||
|
binder.op, resultType, input, norm);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
|
|
|
@ -1423,15 +1423,16 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_lpnormalization
|
// CHECK-LABEL: @test_lpnormalization
|
||||||
func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
|
// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
|
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[OUT:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
|
// CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
|
||||||
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,1,6,7],f32>
|
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32>
|
||||||
%0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
|
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32>
|
||||||
return %0 : !torch.vtensor<[3,4,1,6,7],f32>
|
%0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5,6,7],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue