mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for GlobalMaxPool Op (#3232)
https://github.com/nod-ai/SHARK-Turbine/issues/658

---------

Co-authored-by: root <root@i32b01216.sqa.eu95>
parent 20f312853c
commit 26b78285bf
@@ -1265,6 +1265,83 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        }
        return failure();
      });
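  // onnx.GlobalMaxPool takes no attributes: it max-pools each channel over
  // the entire spatial extent, so the kernel size is simply the input's
  // spatial shape.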
  patterns.onOp(
      "GlobalMaxPool", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();
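        // The kernel size is derived from the input shape, so the input
        // tensor type must carry (possibly dynamic) sizes.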
        auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        unsigned inputRank = inputShape.size();
        if (!resultType || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type having sizes");
        }
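        // Build the per-spatial-dimension pooling parameters: zero padding,
        // unit stride, and unit dilation, with the kernel spanning the full
        // extent of each spatial dim (dims 2..rank-1).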
        SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        for (unsigned i = 2; i < inputRank; i++) {
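          // A dynamic spatial dim cannot be materialized as a constant;
          // query its runtime extent with aten.size.int instead.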
          if (inputShape[i] == Torch::kUnknownSize) {
            Value dim = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(i));
            Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
                binder.getLoc(), operand, dim);
            cstKernel.push_back(inputDimSize);
          } else {
            cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i])));
          }
          cstPadding.push_back(cstZero);
          cstDilations.push_back(cstOne);
          cstStrides.push_back(cstOne);
        }
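        // Package the scalars into !torch.list<int> operands, the form the
        // aten max_pool ops expect.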
        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstKernel);
        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstPadding);
        Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstDilations);
        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstStrides);
        Value cstCeilMode =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

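        // Dispatch on input rank: 3-D, 4-D, and 5-D inputs map to
        // aten.max_pool1d/2d/3d respectively; any other rank fails to match.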
        if (inputRank == 3) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        } else if (inputRank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        } else if (inputRank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        return failure();
      });
  patterns.onOp(
      "LayerNormalization", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {

@@ -743,6 +743,42 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3

// -----

// CHECK-LABEL: @test_globalmaxpool
func.func @test_globalmaxpool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[C5:.*]] = torch.constant.int 5
  // CHECK: %[[C5_0:.*]] = torch.constant.int 5
  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,1,1],f32>
  %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32>
  return %0 : !torch.vtensor<[1,3,1,1],f32>
}

// -----

// CHECK-LABEL: @test_globalmaxpool_precomputed
func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[C3:.*]] = torch.constant.int 3
  // CHECK: %[[C3_0:.*]] = torch.constant.int 3
  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1,1],f32>
  %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32>
  return %0 : !torch.vtensor<[1,1,1,1],f32>
}
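
// -----

// The two tests above cover statically shaped inputs, where the kernel
// extents fold into torch.constant.int ops. Below is a minimal sketch of the
// dynamic-shape path, where the lowering instead emits torch.aten.size.int to
// read each unknown spatial extent at runtime. Illustrative only:
// @test_globalmaxpool_dynamic and its expectations are not part of this
// commit's test suite.
func.func @test_globalmaxpool_dynamic(%arg0: !torch.vtensor<[1,3,?,?],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // Expected IR shape (approximate): each unknown spatial dim i yields
  //   %dim = torch.constant.int i
  //   %size = torch.aten.size.int %arg0, %dim : !torch.vtensor<[1,3,?,?],f32>, !torch.int -> !torch.int
  // and those sizes feed the torch.prim.ListConstruct for the kernel before
  // torch.aten.max_pool2d.
  %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,3,?,?],f32>) -> !torch.vtensor<[1,3,1,1],f32>
  return %0 : !torch.vtensor<[1,3,1,1],f32>
}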

// -----

// CHECK-LABEL: func.func @test_max_example
func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>