mirror of https://github.com/llvm/torch-mlir
[ONNX] Add onnx parser for LpPool operator (#3449)
Similar to https://github.com/llvm/torch-mlir/pull/3435 Solves https://github.com/nod-ai/SHARK-Turbine/issues/728pull/3461/merge
parent
6f94c7b0aa
commit
04c6479350
|
@ -1687,6 +1687,122 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"LpPool", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
|
||||
return failure();
|
||||
if (autoPad != "NOTSET") {
|
||||
// TODO: Add support for `auto_pad` != "NOTSET"
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: auto_pad != NOTSET");
|
||||
}
|
||||
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
int64_t ceilMode, p;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
|
||||
binder.s64IntegerAttr(p, "p", 2) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
// Determine the rank of input tensor.
|
||||
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
|
||||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unimplemented: unranked tensor");
|
||||
unsigned rank = *maybeRank;
|
||||
// only 1D, 2D and 3D LpPool is supported.
|
||||
if (rank > 5 or rank < 3) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> kernel, padding, strides, dilations;
|
||||
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
|
||||
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}) ||
|
||||
binder.s64IntegerArrayAttr(padding, "pads", defaultPadding) ||
|
||||
binder.s64IntegerArrayAttr(
|
||||
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1)) ||
|
||||
binder.s64IntegerArrayAttr(dilations, "dilations", {})) {
|
||||
return failure();
|
||||
}
|
||||
if (kernel.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "kernel list size does not match the number of axes");
|
||||
}
|
||||
if (padding.size() != 2 * (rank - 2)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"padding list size does not match twice the number of axes");
|
||||
}
|
||||
if (strides.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "strides list size does not match the number of axes");
|
||||
}
|
||||
if (dilations.size() > 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "dilation is not supported by torch.aten.avgpool op "
|
||||
"and therefore it is not supported for LpPool.");
|
||||
}
|
||||
|
||||
SmallVector<Value> cstKernel, cstPadding, cstStrides;
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value numElements = cstOne;
|
||||
for (int64_t i : kernel) {
|
||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
numElements = rewriter.create<Torch::AtenMulOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
cstKernel.back(), numElements);
|
||||
}
|
||||
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstKernel);
|
||||
Value paddingList = createConstantIntList(binder, rewriter, padding);
|
||||
Value stridesList = createConstantIntList(binder, rewriter, strides);
|
||||
Value cstCeilMode =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
||||
// onnx lp pool doesn't have countIncludePad attribute but set it to
|
||||
// true so that in 1D case numElements is correctly undoes divison. For
|
||||
// 2D/3D case, division is avoided by divison_override.
|
||||
Value cstCountIncludePad =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
Value pv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
|
||||
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
|
||||
Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
|
||||
inputTensorType, operand);
|
||||
Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
|
||||
binder.getLoc(), inputTensorType, abs, pv);
|
||||
Value avgPool;
|
||||
if (rank == 3) {
|
||||
avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
|
||||
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
|
||||
paddingList, cstCeilMode, cstCountIncludePad);
|
||||
avgPool = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
binder.getLoc(), resultType, avgPool, numElements);
|
||||
} else if (rank == 4) {
|
||||
avgPool = rewriter.create<Torch::AtenAvgPool2dOp>(
|
||||
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
|
||||
paddingList, cstCeilMode, cstCountIncludePad,
|
||||
/*divisor_override=*/cstOne);
|
||||
} else { // rank == 5
|
||||
avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
|
||||
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
|
||||
paddingList, cstCeilMode, cstCountIncludePad,
|
||||
/*divisor_override=*/cstOne);
|
||||
}
|
||||
Value invP = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(double{1.0 / p}));
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenPowTensorScalarOp>(
|
||||
binder.op, resultType, avgPool, invP);
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"LayerNormalization", 17,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
|
|
@ -274,6 +274,62 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lppool_2d
|
||||
func.func @test_lppool_2d(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[I2_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NE1:.*]] = torch.aten.mul %[[I2_1]], %[[NE]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[I1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[CEIL:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[CIP:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[P:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32,32],f32> -> !torch.vtensor<[1,3,32,32],f32>
|
||||
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32,32],f32>, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
|
||||
// CHECK: %[[AVG:.*]] = torch.aten.avg_pool2d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]], %[[I1]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,31,31],f32>
|
||||
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: torch.aten.pow.Tensor_Scalar %[[AVG]], %[[INVP]] : !torch.vtensor<[1,3,31,31],f32>, !torch.float -> !torch.vtensor<[1,3,31,31],f32>
|
||||
%0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
|
||||
return %0 : !torch.vtensor<[1,3,31,31],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lppool_1d
|
||||
func.func @test_lppool_1d(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[CEIL:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[CIP:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[P:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32],f32> -> !torch.vtensor<[1,3,32],f32>
|
||||
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32],f32>, !torch.int -> !torch.vtensor<[1,3,32],f32>
|
||||
// CHECK: %[[AVG:.*]] = torch.aten.avg_pool1d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32>
|
||||
// CHECK: %[[POW_0:.*]] = torch.aten.mul.Scalar %[[AVG]], %[[NE]] : !torch.vtensor<[1,3,31],f32>, !torch.int -> !torch.vtensor<[1,3,31],f32>
|
||||
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: torch.aten.pow.Tensor_Scalar %[[POW_0]], %[[INVP]] : !torch.vtensor<[1,3,31],f32>, !torch.float -> !torch.vtensor<[1,3,31],f32>
|
||||
%0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32>
|
||||
return %0 : !torch.vtensor<[1,3,31],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL : func.func @test_layer_norm
|
||||
func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>)
|
||||
attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
|
|
Loading…
Reference in New Issue