[MLIR][ONNX] Add OnnxToTorch support for GlobalMaxPool Op (#3232)

https://github.com/nod-ai/SHARK-Turbine/issues/658

---------

Co-authored-by: root <root@i32b01216.sqa.eu95>
pull/3346/head
NeverRaR 2024-05-14 18:25:39 +08:00 committed by GitHub
parent 20f312853c
commit 26b78285bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 0 deletions

View File

@ -1265,6 +1265,83 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
} }
return failure(); return failure();
}); });
patterns.onOp(
"GlobalMaxPool", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
}
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
unsigned inputRank = inputShape.size();
if (!resultType || !resultType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected result type having sizes");
}
SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
for (unsigned i = 2; i < inputRank; i++) {
if (inputShape[i] == Torch::kUnknownSize) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i));
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), operand, dim);
cstKernel.push_back(inputDimSize);
} else {
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i])));
}
cstPadding.push_back(cstZero);
cstDilations.push_back(cstOne);
cstStrides.push_back(cstOne);
}
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstKernel);
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstDilations);
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
if (inputRank == 3) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
binder.op, resultType, operand, kernelSizeList, stridesList,
paddingList, dilationsList, cstCeilMode);
return success();
} else if (inputRank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
binder.op, resultType, operand, kernelSizeList, stridesList,
paddingList, dilationsList, cstCeilMode);
return success();
} else if (inputRank == 5) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
binder.op, resultType, operand, kernelSizeList, stridesList,
paddingList, dilationsList, cstCeilMode);
return success();
}
return failure();
});
patterns.onOp( patterns.onOp(
"LayerNormalization", 17, "LayerNormalization", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -743,6 +743,42 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
// ----- // -----
// CHECK-LABEL: @test_globalmaxpool
func.func @test_globalmaxpool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[C5_0:.*]] = torch.constant.int 5
// CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,1,1],f32>
%0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32>
return %0 : !torch.vtensor<[1,3,1,1],f32>
}
// -----
// CHECK-LABEL: @test_globalmaxpool_precomputed
func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1,1],f32>
%0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_max_example // CHECK-LABEL: func.func @test_max_example
func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>