mirror of https://github.com/llvm/torch-mlir
Add onnx op LRN lowering (#3432)
This commit adds support for lowering Onnx LRN op to aten.pull/3461/merge
parent
09c988046c
commit
d2b663ece7
|
@ -34,7 +34,7 @@ namespace mlir::torch::onnx_c {
|
|||
|
||||
Value createConstantIntList(OpBinder binder,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVector<int64_t> cstInput);
|
||||
ArrayRef<int64_t> cstInput);
|
||||
|
||||
Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
|
||||
|
||||
|
|
|
@ -1945,6 +1945,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, resultType, operand, constAlpha);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
int64_t size;
|
||||
float alpha, beta, bias;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(size, "size", 2) ||
|
||||
binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
|
||||
binder.f32FloatAttr(beta, "beta", 0.75f) ||
|
||||
binder.f32FloatAttr(bias, "bias", 1.0f))
|
||||
return failure();
|
||||
Type dtype = resultType.getOptionalDtype();
|
||||
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(alpha));
|
||||
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(beta));
|
||||
Value constBias = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(bias));
|
||||
// Please refer to the operator description
|
||||
// for more info on the lowering
|
||||
// https://onnx.ai/onnx/operators/onnx__LRN.html
|
||||
|
||||
// squared = operand^2
|
||||
Location loc = binder.getLoc();
|
||||
Torch::ValueTensorType inTy =
|
||||
cast<Torch::ValueTensorType>(operand.getType());
|
||||
Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
|
||||
loc, inTy, operand, operand);
|
||||
// view it as n x 1 x c x d0 x d..
|
||||
if (!inTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Expected input to have sizes");
|
||||
}
|
||||
ArrayRef<int64_t> inTyShape = inTy.getSizes();
|
||||
if (inTyShape.size() < 3) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unsupported: the input dimensions should be >= 3");
|
||||
}
|
||||
if (inTyShape[1] == Torch::kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unsupported: the second dimension size must be "
|
||||
"statically known");
|
||||
}
|
||||
SmallVector<int64_t, 5> viewShapeInt{inTyShape[0], 1, inTyShape[1],
|
||||
inTyShape[2], Torch::kUnknownSize};
|
||||
Torch::ValueTensorType reshapeType =
|
||||
rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
|
||||
Value viewShapeListVal =
|
||||
createConstantIntList(binder, rewriter, viewShapeInt);
|
||||
auto view = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, reshapeType, sqOperand, viewShapeListVal);
|
||||
// padding
|
||||
int64_t highPad = (size - 1) / 2;
|
||||
int64_t lowPad = (size - 1) - highPad;
|
||||
SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
|
||||
auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(0.0));
|
||||
Value paddingListVal =
|
||||
createConstantIntList(binder, rewriter, paddingInt);
|
||||
SmallVector<int64_t, 5> paddedShapeInt = viewShapeInt;
|
||||
paddedShapeInt[2] += size - 1;
|
||||
Torch::ValueTensorType paddedType =
|
||||
rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
|
||||
auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
|
||||
loc, paddedType, view, paddingListVal, constPadVal);
|
||||
// avg_pool3d
|
||||
SmallVector<int64_t, 3> kernelSize{size, 1, 1};
|
||||
Value kernelSizeList =
|
||||
createConstantIntList(binder, rewriter, kernelSize);
|
||||
SmallVector<int64_t, 3> strides{1, 1, 1};
|
||||
Value stridesList = createConstantIntList(binder, rewriter, strides);
|
||||
SmallVector<int64_t, 3> padding{0, 0, 0};
|
||||
Value paddingList = createConstantIntList(binder, rewriter, padding);
|
||||
auto cstCeilMode =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
auto cstCountIncludeMode =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
// Output of pooling is same reshape(view) type because
|
||||
// of the padding done on the dimensions being pooled.
|
||||
auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
|
||||
loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
|
||||
cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
|
||||
// squeeze
|
||||
auto one = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
SmallVector<int64_t, 5> squeezeShapeInt{
|
||||
viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
|
||||
Torch::ValueTensorType squeezeType =
|
||||
rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
|
||||
auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
loc, squeezeType, pool, one);
|
||||
// view as input Type
|
||||
Value intTyShapeList =
|
||||
createConstantIntList(binder, rewriter, inTyShape);
|
||||
auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, inTy, squeeze, intTyShapeList);
|
||||
// mul + add + pow + div
|
||||
auto mul = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
loc, resultType, viewAsInput, constAlpha);
|
||||
auto add = rewriter.create<Torch::AtenAddScalarOp>(loc, resultType, mul,
|
||||
constBias, one);
|
||||
auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
|
||||
loc, resultType, add, constBeta);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
|
||||
binder.op, resultType, operand, pow);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -16,7 +16,7 @@ using namespace mlir::torch::onnx_c;
|
|||
|
||||
Value mlir::torch::onnx_c::createConstantIntList(
|
||||
OpBinder binder, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<int64_t> cstInput) {
|
||||
ArrayRef<int64_t> cstInput) {
|
||||
SmallVector<Value> cstValue;
|
||||
for (int64_t i : cstInput) {
|
||||
cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
|
|
|
@ -366,6 +366,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lrn_default
|
||||
func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
|
||||
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
|
||||
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
|
||||
|
||||
// CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
|
||||
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
|
||||
// CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
|
||||
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]
|
||||
|
||||
// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
|
||||
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]
|
||||
|
||||
// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
|
||||
|
||||
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]
|
||||
|
||||
// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
|
||||
|
||||
// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
|
||||
|
||||
// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
|
||||
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
|
||||
|
||||
// CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
|
||||
// CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
|
||||
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
|
||||
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]
|
||||
|
||||
// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
|
||||
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
|
||||
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
|
||||
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
|
||||
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
|
||||
// CHECK: return %[[OUTPUT]]
|
||||
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
|
||||
return %0 : !torch.vtensor<[20,10,3,50],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lrn_with_optionals
|
||||
func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
|
||||
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
|
||||
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
|
||||
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
|
||||
|
||||
// CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
|
||||
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
|
||||
// CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
|
||||
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
|
||||
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]
|
||||
|
||||
// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
|
||||
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]
|
||||
|
||||
// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
|
||||
|
||||
// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
|
||||
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]
|
||||
|
||||
// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
|
||||
|
||||
// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
|
||||
|
||||
// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
|
||||
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
|
||||
|
||||
// CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
|
||||
// CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
|
||||
// CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
|
||||
// CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
|
||||
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]
|
||||
|
||||
// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
|
||||
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
|
||||
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
|
||||
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
|
||||
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
|
||||
// CHECK: return %[[OUTPUT]]
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
|
||||
return %0 : !torch.vtensor<[13,19,100,200],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_matmul_2d
|
||||
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
|
||||
|
|
Loading…
Reference in New Issue