Fix concat order of result

jinchen62 2024-11-25 21:13:03 -08:00
parent 7e69602cad
commit e1c19ec319
1 changed files with 6 additions and 11 deletions

View File

@ -3717,10 +3717,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
boxes = squeezedBoxes.value();
scores = squeezedScores.value();
// TODO: Add support for handling score_threshold arg.
// If score_threshold > min(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
@ -3742,6 +3740,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"unimplemented: score_threshold should be <= min(scores)"));
}
// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
auto nmsTy = Torch::ValueTensorType::get(
@ -3796,14 +3795,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});
binder.getLoc(), listType, SmallVector<Value>{zeros, result});
// TODO: Add support for handling max_output_boxes_per_class arg.
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
// be lowered since the torchvision::nms op doesn't have support for
// handling the max_output_boxes_per_class arg. Also, we have already
// constrained the number of classes to be 1 above, so the number of
// output boxes inferred from the result is num_output_boxes_per_class.
// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(