mirror of https://github.com/llvm/torch-mlir
Fix concat order of result
parent
7e69602cad
commit
e1c19ec319
|
@ -3717,10 +3717,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
boxes = squeezedBoxes.value();
|
||||
scores = squeezedScores.value();
|
||||
|
||||
// TODO: Add support for handling score_threshold arg.
|
||||
// If score_threshold > min(scores) then the op can't be lowered since
|
||||
// the torchvision::nms op doesn't have support for handling the
|
||||
// score_threshold arg.
|
||||
// TODO: Support score_threshold input
|
||||
// Filter out the boxes if the score < score_threshold
|
||||
if (operands.size() == 5) {
|
||||
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
|
@ -3742,6 +3740,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
"unimplemented: score_threshold should be <= min(scores)"));
|
||||
}
|
||||
|
||||
// TODO: Support default iou_threshold
|
||||
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
|
||||
auto nmsTy = Torch::ValueTensorType::get(
|
||||
|
@ -3796,14 +3795,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});
|
||||
binder.getLoc(), listType, SmallVector<Value>{zeros, result});
|
||||
|
||||
// TODO: Add support for handling max_output_boxes_per_class arg.
|
||||
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
|
||||
// be lowered since the torchvision::nms op doesn't have support for
|
||||
// handling the max_output_boxes_per_class arg. Also, we have already
|
||||
// constrained the number of classes to be 1 above, so the number of
|
||||
// output boxes inferred from the result is num_output_boxes_per_class.
|
||||
// TODO: Support max_output_boxes_per_class input
|
||||
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
|
||||
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
|
||||
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
|
||||
|
|
Loading…
Reference in New Issue