mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix Onnx.TopK lowering (#3103)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3107/head
parent
7e778e2179
commit
af54d27820
|
@ -2120,20 +2120,20 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"TopK", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType Values_type, Indices_type;
|
Torch::ValueTensorType Values_type, Indices_type;
|
||||||
Value X, K;
|
Value input, kValue;
|
||||||
int64_t axis;
|
int64_t axis;
|
||||||
bool largest, sorted;
|
bool largest, sorted;
|
||||||
if (binder.tensorOperandAtIndex(X, 0) ||
|
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||||
binder.tensorOperandAtIndex(K, 1) ||
|
binder.tensorOperandAtIndex(kValue, 1) ||
|
||||||
binder.s64IntegerAttr(axis, "axis", -1) ||
|
binder.s64IntegerAttr(axis, "axis", -1) ||
|
||||||
binder.s64BoolAttr(largest, "largest", true) ||
|
binder.s64BoolAttr(largest, "largest", true) ||
|
||||||
binder.s64BoolAttr(sorted, "sorted", true) ||
|
binder.s64BoolAttr(sorted, "sorted", true) ||
|
||||||
binder.tensorResultTypeAtIndex(Values_type, 0) ||
|
binder.tensorResultTypeAtIndex(Values_type, 0) ||
|
||||||
binder.tensorResultTypeAtIndex(Indices_type, 1))
|
binder.tensorResultTypeAtIndex(Indices_type, 1))
|
||||||
return failure();
|
return failure();
|
||||||
std::optional<unsigned> maybeRank = Torch::getTensorRank(X);
|
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
|
||||||
if (!maybeRank)
|
if (!maybeRank)
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"Unimplemented: unranked tensor");
|
"Unimplemented: unranked tensor");
|
||||||
|
@ -2145,9 +2145,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), largest);
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), largest);
|
||||||
Value cstSorted =
|
Value cstSorted =
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), sorted);
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), sorted);
|
||||||
|
Value kValueInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), kValue);
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenTopkOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenTopkOp>(
|
||||||
binder.op, Values_type, Indices_type, X, K, cstAxis, cstLargest,
|
binder.op, Values_type, Indices_type, input, kValueInt, cstAxis,
|
||||||
cstSorted);
|
cstLargest, cstSorted);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp("Sign", 9,
|
patterns.onOp("Sign", 9,
|
||||||
|
|
|
@ -2042,13 +2042,6 @@ ONNX_XFAIL_SET = {
|
||||||
"SqueezeModule_broadcast",
|
"SqueezeModule_broadcast",
|
||||||
"SqueezeModule_static",
|
"SqueezeModule_static",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.TopK
|
|
||||||
"SortTensorDescending_basic",
|
|
||||||
"SortTensorInteger_basic",
|
|
||||||
"SortTensorNegativeDimension_basic",
|
|
||||||
"SortTensorSpecificDimension_basic",
|
|
||||||
"SortTensor_basic",
|
|
||||||
|
|
||||||
# Failure - incorrect dtype
|
# Failure - incorrect dtype
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
|
|
|
@ -1589,32 +1589,41 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL : func.func @test_top_k
|
// CHECK-LABEL : func.func @test_top_k
|
||||||
func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
|
func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
|
||||||
// CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||||
// CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[LARGEST:.*]] = torch.constant.bool true
|
||||||
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[SORTED:.*]] = torch.constant.bool true
|
||||||
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
}
|
// CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
||||||
|
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_top_k_smallest
|
// CHECK-LABEL: func.func @test_top_k_smallest
|
||||||
func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||||
// CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[LARGEST:.*]] = torch.constant.bool false
|
||||||
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[SORTED:.*]] = torch.constant.bool true
|
||||||
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
}
|
// CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
||||||
|
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_top_k_negative_axis
|
// CHECK-LABEL: func.func @test_top_k_negative_axis
|
||||||
func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
|
func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
|
||||||
// CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||||
// CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[LARGEST:.*]] = torch.constant.bool true
|
||||||
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
// CHECK: %[[SORTED:.*]] = torch.constant.bool true
|
||||||
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
// CHECK: %[[K:.*]] = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
}
|
// CHECK: torch.aten.topk %arg0, %[[K]], %[[AXIS]], %[[LARGEST]], %[[SORTED]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
%0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>)
|
||||||
|
return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue