Add onnx op LRN lowering (#3432)

This commit adds support for lowering
Onnx LRN op to aten.
pull/3461/merge
Manupa Karunaratne 2024-06-14 17:44:43 +01:00 committed by GitHub
parent 09c988046c
commit d2b663ece7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 248 additions and 2 deletions

View File

@ -34,7 +34,7 @@ namespace mlir::torch::onnx_c {
Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput);
ArrayRef<int64_t> cstInput);
Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);

View File

@ -1945,6 +1945,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand, constAlpha);
return success();
});
patterns.onOp(
"LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t size;
float alpha, beta, bias;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(size, "size", 2) ||
binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
binder.f32FloatAttr(beta, "beta", 0.75f) ||
binder.f32FloatAttr(bias, "bias", 1.0f))
return failure();
Type dtype = resultType.getOptionalDtype();
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(beta));
Value constBias = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(bias));
// Please refer to the operator description
// for more info on the lowering
// https://onnx.ai/onnx/operators/onnx__LRN.html
// squared = operand^2
Location loc = binder.getLoc();
Torch::ValueTensorType inTy =
cast<Torch::ValueTensorType>(operand.getType());
Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
loc, inTy, operand, operand);
// view it as n x 1 x c x d0 x d..
if (!inTy.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expected input to have sizes");
}
ArrayRef<int64_t> inTyShape = inTy.getSizes();
if (inTyShape.size() < 3) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: the input dimensions should be >= 3");
}
if (inTyShape[1] == Torch::kUnknownSize) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: the second dimension size must be "
"statically known");
}
SmallVector<int64_t, 5> viewShapeInt{inTyShape[0], 1, inTyShape[1],
inTyShape[2], Torch::kUnknownSize};
Torch::ValueTensorType reshapeType =
rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
Value viewShapeListVal =
createConstantIntList(binder, rewriter, viewShapeInt);
auto view = rewriter.create<Torch::AtenViewOp>(
loc, reshapeType, sqOperand, viewShapeListVal);
// padding
int64_t highPad = (size - 1) / 2;
int64_t lowPad = (size - 1) - highPad;
SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(0.0));
Value paddingListVal =
createConstantIntList(binder, rewriter, paddingInt);
SmallVector<int64_t, 5> paddedShapeInt = viewShapeInt;
paddedShapeInt[2] += size - 1;
Torch::ValueTensorType paddedType =
rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
loc, paddedType, view, paddingListVal, constPadVal);
// avg_pool3d
SmallVector<int64_t, 3> kernelSize{size, 1, 1};
Value kernelSizeList =
createConstantIntList(binder, rewriter, kernelSize);
SmallVector<int64_t, 3> strides{1, 1, 1};
Value stridesList = createConstantIntList(binder, rewriter, strides);
SmallVector<int64_t, 3> padding{0, 0, 0};
Value paddingList = createConstantIntList(binder, rewriter, padding);
auto cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto cstCountIncludeMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
// Output of pooling is same reshape(view) type because
// of the padding done on the dimensions being pooled.
auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
// squeeze
auto one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<int64_t, 5> squeezeShapeInt{
viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
Torch::ValueTensorType squeezeType =
rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
loc, squeezeType, pool, one);
// view as input Type
Value intTyShapeList =
createConstantIntList(binder, rewriter, inTyShape);
auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
loc, inTy, squeeze, intTyShapeList);
// mul + add + pow + div
auto mul = rewriter.create<Torch::AtenMulScalarOp>(
loc, resultType, viewAsInput, constAlpha);
auto add = rewriter.create<Torch::AtenAddScalarOp>(loc, resultType, mul,
constBias, one);
auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
loc, resultType, add, constBeta);
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
binder.op, resultType, operand, pow);
return success();
});
patterns.onOp(
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -16,7 +16,7 @@ using namespace mlir::torch::onnx_c;
Value mlir::torch::onnx_c::createConstantIntList(
OpBinder binder, ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput) {
ArrayRef<int64_t> cstInput) {
SmallVector<Value> cstValue;
for (int64_t i : cstInput) {
cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(

View File

@ -366,6 +366,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
// -----
// CHECK-LABEL: func.func @test_lrn_default
func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
// CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]
// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]
// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]
// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
// CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]
// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
return %0 : !torch.vtensor<[20,10,3,50],f32>
}
// -----
// CHECK-LABEL: func.func @test_lrn_with_optionals
func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
// CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]
// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]
// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]
// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
// CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
// CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]
// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%none = torch.constant.none
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
return %0 : !torch.vtensor<[13,19,100,200],f32>
}
// -----
// CHECK-LABEL: @test_matmul_2d
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>