[ONNX] Add onnx parser for LpPool operator (#3449)

Similar to https://github.com/llvm/torch-mlir/pull/3435

Solves https://github.com/nod-ai/SHARK-Turbine/issues/728
pull/3461/merge
Umang Yadav 2024-06-14 12:11:18 -04:00 committed by GitHub
parent 6f94c7b0aa
commit 04c6479350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 172 additions and 0 deletions

View File

@ -1687,6 +1687,122 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success(); return success();
}); });
patterns.onOp(
"LpPool", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET") {
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}
Torch::ValueTensorType resultType;
Value operand;
int64_t ceilMode, p;
if (binder.tensorOperand(operand) ||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
binder.s64IntegerAttr(p, "p", 2) ||
binder.tensorResultType(resultType))
return failure();
// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
// only 1D, 2D and 3D LpPool is supported.
if (rank > 5 or rank < 3) {
return failure();
}
SmallVector<int64_t> kernel, padding, strides, dilations;
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}) ||
binder.s64IntegerArrayAttr(padding, "pads", defaultPadding) ||
binder.s64IntegerArrayAttr(
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1)) ||
binder.s64IntegerArrayAttr(dilations, "dilations", {})) {
return failure();
}
if (kernel.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "kernel list size does not match the number of axes");
}
if (padding.size() != 2 * (rank - 2)) {
return rewriter.notifyMatchFailure(
binder.op,
"padding list size does not match twice the number of axes");
}
if (strides.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
}
if (dilations.size() > 0) {
return rewriter.notifyMatchFailure(
binder.op, "dilation is not supported by torch.aten.avgpool op "
"and therefore it is not supported for LpPool.");
}
SmallVector<Value> cstKernel, cstPadding, cstStrides;
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value numElements = cstOne;
for (int64_t i : kernel) {
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
numElements = rewriter.create<Torch::AtenMulOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
cstKernel.back(), numElements);
}
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstKernel);
Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesList = createConstantIntList(binder, rewriter, strides);
Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
// onnx lp pool doesn't have countIncludePad attribute but set it to
// true so that in 1D case numElements is correctly undoes divison. For
// 2D/3D case, division is avoided by divison_override.
Value cstCountIncludePad =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value pv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
inputTensorType, operand);
Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
binder.getLoc(), inputTensorType, abs, pv);
Value avgPool;
if (rank == 3) {
avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
paddingList, cstCeilMode, cstCountIncludePad);
avgPool = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, avgPool, numElements);
} else if (rank == 4) {
avgPool = rewriter.create<Torch::AtenAvgPool2dOp>(
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
paddingList, cstCeilMode, cstCountIncludePad,
/*divisor_override=*/cstOne);
} else { // rank == 5
avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
paddingList, cstCeilMode, cstCountIncludePad,
/*divisor_override=*/cstOne);
}
Value invP = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(double{1.0 / p}));
rewriter.replaceOpWithNewOp<Torch::AtenPowTensorScalarOp>(
binder.op, resultType, avgPool, invP);
return success();
});
patterns.onOp( patterns.onOp(
"LayerNormalization", 17, "LayerNormalization", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -274,6 +274,62 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.
// ----- // -----
// CHECK-LABEL: func.func @test_lppool_2d
func.func @test_lppool_2d(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[I2_1:.*]] = torch.constant.int 2
// CHECK: %[[NE1:.*]] = torch.aten.mul %[[I2_1]], %[[NE]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
// CHECK: %[[I1_2:.*]] = torch.constant.int 1
// CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CEIL:.*]] = torch.constant.bool false
// CHECK: %[[CIP:.*]] = torch.constant.bool true
// CHECK: %[[P:.*]] = torch.constant.int 2
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32,32],f32> -> !torch.vtensor<[1,3,32,32],f32>
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32,32],f32>, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
// CHECK: %[[AVG:.*]] = torch.aten.avg_pool2d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]], %[[I1]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,31,31],f32>
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
// CHECK: torch.aten.pow.Tensor_Scalar %[[AVG]], %[[INVP]] : !torch.vtensor<[1,3,31,31],f32>, !torch.float -> !torch.vtensor<[1,3,31,31],f32>
%0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
return %0 : !torch.vtensor<[1,3,31,31],f32>
}
// -----
// CHECK-LABEL: func.func @test_lppool_1d
func.func @test_lppool_1d(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
// CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[CEIL:.*]] = torch.constant.bool false
// CHECK: %[[CIP:.*]] = torch.constant.bool true
// CHECK: %[[P:.*]] = torch.constant.int 2
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32],f32> -> !torch.vtensor<[1,3,32],f32>
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32],f32>, !torch.int -> !torch.vtensor<[1,3,32],f32>
// CHECK: %[[AVG:.*]] = torch.aten.avg_pool1d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32>
// CHECK: %[[POW_0:.*]] = torch.aten.mul.Scalar %[[AVG]], %[[NE]] : !torch.vtensor<[1,3,31],f32>, !torch.int -> !torch.vtensor<[1,3,31],f32>
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
// CHECK: torch.aten.pow.Tensor_Scalar %[[POW_0]], %[[INVP]] : !torch.vtensor<[1,3,31],f32>, !torch.float -> !torch.vtensor<[1,3,31],f32>
%0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32>
return %0 : !torch.vtensor<[1,3,31],f32>
}
// -----
// CHECK-LABEL : func.func @test_layer_norm // CHECK-LABEL : func.func @test_layer_norm
func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>)
attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {