From af54d278203625b6e58f86ad51856b3412d2d301 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 3 Apr 2024 22:12:48 +0530 Subject: [PATCH] [MLIR][TORCH] Fix Onnx.TopK lowering (#3103) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 16 ++++--- projects/pt1/e2e_testing/xfail_sets.py | 7 --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 45 +++++++++++-------- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 997be2430..1996aa91d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2120,20 +2120,20 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "TopK", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType Values_type, Indices_type; - Value X, K; + Value input, kValue; int64_t axis; bool largest, sorted; - if (binder.tensorOperandAtIndex(X, 0) || - binder.tensorOperandAtIndex(K, 1) || + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(kValue, 1) || binder.s64IntegerAttr(axis, "axis", -1) || binder.s64BoolAttr(largest, "largest", true) || binder.s64BoolAttr(sorted, "sorted", true) || binder.tensorResultTypeAtIndex(Values_type, 0) || binder.tensorResultTypeAtIndex(Indices_type, 1)) return failure(); - std::optional maybeRank = Torch::getTensorRank(X); + std::optional maybeRank = Torch::getTensorRank(input); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); @@ -2145,9 +2145,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.create(binder.getLoc(), largest); Value cstSorted = rewriter.create(binder.getLoc(), sorted); + Value kValueInt = rewriter.create( + binder.getLoc(), rewriter.getType(), kValue); rewriter.replaceOpWithNewOp( - binder.op, Values_type, Indices_type, X, K, cstAxis, cstLargest, - cstSorted); + binder.op, Values_type, Indices_type, input, kValueInt, cstAxis, + cstLargest, cstSorted); return success(); }); patterns.onOp("Sign", 9, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2407ce2b7..44594e9b6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2042,13 +2042,6 @@ ONNX_XFAIL_SET = { "SqueezeModule_broadcast", "SqueezeModule_static", - # Failure - onnx_lowering: onnx.TopK - "SortTensorDescending_basic", - "SortTensorInteger_basic", - "SortTensorNegativeDimension_basic", - "SortTensorSpecificDimension_basic", - "SortTensor_basic", - # Failure - incorrect dtype "ReduceMaxAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index ee89cd09e..a46654e45 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1589,32 +1589,41 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> // ----- // CHECK-LABEL : func.func @test_top_k - func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { - // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - } +func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[AXIS:.*]] = torch.constant.int 1 + // CHECK: %[[LARGEST:.*]] = torch.constant.bool true + // CHECK: %[[SORTED:.*]] = torch.constant.bool true + // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> +} // ----- // CHECK-LABEL: func.func @test_top_k_smallest - func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - } +func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[AXIS:.*]] = torch.constant.int 1 + // CHECK: %[[LARGEST:.*]] = torch.constant.bool false + // CHECK: %[[SORTED:.*]] = torch.constant.bool true + // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> +} // ----- // CHECK-LABEL: func.func @test_top_k_negative_axis - func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { - // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) - return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - } +func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[AXIS:.*]] = torch.constant.int 1 + // CHECK: %[[LARGEST:.*]] = torch.constant.bool true + // CHECK: %[[SORTED:.*]] = torch.constant.bool true + // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> +} // -----