mirror of https://github.com/llvm/torch-mlir
[ONNX] Fix LpNormalization Lowering (#3521)
The LpNormalization lowering was previously just computing the norm, which is incorrect. This computes the norm then divides the input tensor by it's norm. I've tested this against some simple onnx models locally. I'll look into adding a test case for this in an external test suite.pull/2520/merge
parent
0b46d1110a
commit
dcb48dd46c
|
@ -2674,36 +2674,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"LpNormalization", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t axis, p;
|
||||
Value input;
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.s64IntegerAttr(axis, "axis", -1) ||
|
||||
binder.s64IntegerAttr(p, "p", 2) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
patterns.onOp("LpNormalization", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t axis, p;
|
||||
Value input;
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.s64IntegerAttr(axis, "axis", -1) ||
|
||||
binder.s64IntegerAttr(p, "p", 2) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto loc = binder.getLoc();
|
||||
Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(axis));
|
||||
Value cstP = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(p));
|
||||
Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
|
||||
loc, rewriter.getBoolAttr(true));
|
||||
Value axisPrimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
llvm::ArrayRef<Value>{cstAxis});
|
||||
auto loc = binder.getLoc();
|
||||
Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(axis));
|
||||
Value cstP = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(p));
|
||||
Value cstKeepDim = rewriter.create<Torch::ConstantBoolOp>(
|
||||
loc, rewriter.getBoolAttr(true));
|
||||
Value axisPrimList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
llvm::ArrayRef<Value>{cstAxis});
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenNormScalarOptDimOp>(
|
||||
binder.op, resultType, input, cstP, axisPrimList, cstKeepDim);
|
||||
SmallVector<int64_t> normSizes(resultType.getSizes());
|
||||
int64_t rank = normSizes.size();
|
||||
axis = axis % rank;
|
||||
axis = (axis < 0) ? axis + rank : axis;
|
||||
normSizes[axis] = 1;
|
||||
auto normType = rewriter.getType<Torch::ValueTensorType>(
|
||||
normSizes, resultType.getDtype());
|
||||
Value norm = rewriter.create<Torch::AtenNormScalarOptDimOp>(
|
||||
loc, normType, input, cstP, axisPrimList, cstKeepDim);
|
||||
|
||||
return success();
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
|
||||
binder.op, resultType, input, norm);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// TODO: Add support for `output_shape` arg.
|
||||
|
|
|
@ -1423,15 +1423,16 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_lpnormalization
|
||||
func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,1,6,7],f32>
|
||||
%0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
|
||||
return %0 : !torch.vtensor<[3,4,1,6,7],f32>
|
||||
// CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32>
|
||||
%0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5,6,7],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue