[MLIR][TORCH] Fix Onnx.TopK lowering (#3103)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Vivek Khandelwal 2024-04-03 22:12:48 +05:30 committed by GitHub
parent 7e778e2179
commit af54d27820
3 changed files with 36 additions and 32 deletions
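
This commit makes two fixes to the onnx.TopK lowering, both visible in the diff below: the pattern was registered under the misspelled name "Topk", so it never matched the case-sensitive ONNX op name "TopK", and the K operand was passed straight to torch.aten.topk even though ONNX supplies K as a rank-1 int64 tensor while the Torch op expects a plain integer. For context, here is a minimal sketch of the opset-11 operator being matched, built with the standard onnx Python helper API (illustrative only, not part of the patch); note that K is a [1]-shaped tensor input rather than an attribute:

    import onnx
    from onnx import TensorProto, helper

    # onnx.TopK at opset 11: axis/largest/sorted are attributes,
    # but K arrives as a rank-1 int64 tensor input.
    node = helper.make_node("TopK", inputs=["X", "K"],
                            outputs=["Values", "Indices"],
                            axis=1, largest=1, sorted=1)
    graph = helper.make_graph(
        [node], "topk_example",
        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]),
                helper.make_tensor_value_info("K", TensorProto.INT64, [1])],
        outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, [3, 3]),
                 helper.make_tensor_value_info("Indices", TensorProto.INT64, [3, 3])])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
    onnx.checker.check_model(model)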


@@ -2120,20 +2120,20 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });
   patterns.onOp(
-      "Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "TopK", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType Values_type, Indices_type;
-        Value X, K;
+        Value input, kValue;
         int64_t axis;
         bool largest, sorted;
-        if (binder.tensorOperandAtIndex(X, 0) ||
-            binder.tensorOperandAtIndex(K, 1) ||
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.tensorOperandAtIndex(kValue, 1) ||
             binder.s64IntegerAttr(axis, "axis", -1) ||
             binder.s64BoolAttr(largest, "largest", true) ||
             binder.s64BoolAttr(sorted, "sorted", true) ||
             binder.tensorResultTypeAtIndex(Values_type, 0) ||
             binder.tensorResultTypeAtIndex(Indices_type, 1))
           return failure();
-        std::optional<unsigned> maybeRank = Torch::getTensorRank(X);
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
         if (!maybeRank)
           return rewriter.notifyMatchFailure(binder.op,
                                              "Unimplemented: unranked tensor");
@@ -2145,9 +2145,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), largest);
         Value cstSorted =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), sorted);
+        Value kValueInt = rewriter.create<Torch::AtenItemOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(), kValue);
         rewriter.replaceOpWithNewOp<Torch::AtenTopkOp>(
-            binder.op, Values_type, Indices_type, X, K, cstAxis, cstLargest,
-            cstSorted);
+            binder.op, Values_type, Indices_type, input, kValueInt, cstAxis,
+            cstLargest, cstSorted);
         return success();
       });
   patterns.onOp("Sign", 9,


@@ -2042,13 +2042,6 @@ ONNX_XFAIL_SET = {
     "SqueezeModule_broadcast",
     "SqueezeModule_static",
-    # Failure - onnx_lowering: onnx.TopK
-    "SortTensorDescending_basic",
-    "SortTensorInteger_basic",
-    "SortTensorNegativeDimension_basic",
-    "SortTensorSpecificDimension_basic",
-    "SortTensor_basic",
     # Failure - incorrect dtype
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",


@@ -1590,8 +1590,11 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
 // CHECK-LABEL: func.func @test_top_k
 func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
-  // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
-  // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[LARGEST:.*]] = torch.constant.bool true
+  // CHECK: %[[SORTED:.*]] = torch.constant.bool true
+  // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
   %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
   return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
 }
@@ -1600,8 +1603,11 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
 // CHECK-LABEL: func.func @test_top_k_smallest
 func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
-  // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[LARGEST:.*]] = torch.constant.bool false
+  // CHECK: %[[SORTED:.*]] = torch.constant.bool true
+  // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
   %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
   return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
 }
@@ -1610,8 +1616,11 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
 // CHECK-LABEL: func.func @test_top_k_negative_axis
 func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
-  // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
-  // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[LARGEST:.*]] = torch.constant.bool true
+  // CHECK: %[[SORTED:.*]] = torch.constant.bool true
+  // CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
   %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
   return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
 }
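
The last test exercises negative-axis handling: the op carries torch.onnx.axis = -1, yet the lowering emits torch.constant.int 1, meaning the axis is normalized against the rank-2 input in code elided from this diff. A sketch of that normalization arithmetic (hypothetical variable names):

    # axis = -1 on a [3, 4] input normalizes to dimension 1.
    rank = 2
    axis = -1
    if axis < 0:
        axis += rank
    assert axis == 1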