From 378454ed853e2d48cd5ea8f4fdfb5bba2bc0f7a6 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 26 Aug 2024 17:24:36 +0000 Subject: [PATCH] onnxtotorch NonMaxSuppression lowering --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 851 +++++++++++++++--- 1 file changed, 739 insertions(+), 112 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index baac6d963..9b87a1d91 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3552,123 +3552,750 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure( binder.op, "unimplemented: expected all 5 args to be present"); - // Squeeze the boxes and scores tensor. - // In Onnx, the shape of boxes is [BxNx4] while the - // torchvision expects it to be of shape [Nx4]. Similarly, for - // the scores tensor shape in Onnx is [BxCxN] while the - // torchvision expects it to be of shape [N]. Value boxes = operands[0], scores = operands[1]; - FailureOr squeezedBoxes = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, boxes); - if (failed(squeezedBoxes)) - return rewriter.notifyMatchFailure(binder.op, - "failed to squeeze boxes tensor"); - FailureOr squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, scores); - if (failed(squeezedScores)) - return rewriter.notifyMatchFailure(binder.op, - "failed to squeeze scores tensor"); - squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value()); - if (failed(squeezedScores)) - return rewriter.notifyMatchFailure(binder.op, - "failed to squeeze scores tensor"); - - boxes = squeezedBoxes.value(); - scores = squeezedScores.value(); - - // TODO: Add support for handling score_threshold arg. - // If score_threshold > min(scores) then the op can't be lowered since - // the torchvision::nms op doesn't have support for handling the - // score_threshold arg. - Value scoreThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[4]); - Value minScores = rewriter.create( - binder.getLoc(), - Torch::ValueTensorType::get(binder.op->getContext(), {}, - rewriter.getF32Type()), - scores); - minScores = rewriter.create( - binder.getLoc(), rewriter.getType(), minScores); - - Value scoresCond = rewriter.create( - binder.getLoc(), minScores, scoreThreshold); - rewriter.create( - binder.getLoc(), scoresCond, - rewriter.getStringAttr( - "unimplemented: score_threshold should be <= min(scores)")); - - Value iouThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[3]); - Value result = rewriter.create( - binder.getLoc(), resultType, boxes, scores, iouThreshold); - - // The result generated by torchvision.nms op is of shape [n], while the - // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor - // and make it of shape [n, 1] and then concatenate it with a zero - // tensor of shape [n, 2] to make it of shape [n, 3]. - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - FailureOr unsqueezedResult = - Torch::unsqueezeTensor(rewriter, binder.op, result, dim); - if (failed(unsqueezedResult)) - return rewriter.notifyMatchFailure( - binder.op, "failed to unsqueeze result tensor"); - result = unsqueezedResult.value(); - - Value numOutputBoxes = rewriter.create( - binder.getLoc(), result, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); - SmallVector zerosShapeValues{numOutputBoxes}; - zerosShapeValues.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2))); - Value zerosShapeList = rewriter.create( - binder.getLoc(), - rewriter.getType( - rewriter.getType()), - zerosShapeValues); - - std::optional> resultShape = - cast(result.getType()).getOptionalSizes(); - if (!resultShape.has_value()) - return rewriter.notifyMatchFailure( - binder.op, "expected result tensor to have shape"); - llvm::SmallVector zerosShape = {resultShape->front(), 2}; - auto zerosTy = Torch::ValueTensorType::get( - resultType.getContext(), zerosShape, resultType.getOptionalDtype()); - Value cstNone = rewriter.create(binder.getLoc()); - Value zeros = rewriter.create( - binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone, - cstNone); - - Type listElemType = - cast(resultType) - .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, - /*optionalDtype=*/nullptr); - Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - binder.op->getLoc(), listType, SmallVector{result, zeros}); - - // TODO: Add support for handling max_output_boxes_per_class arg. - // If numOutputBoxes (N) > max_output_boxes_per_class then the op can't - // be lowered since the torchvision::nms op doesn't have support for - // handling the max_output_boxes_per_class arg. Also, we have already - // constrained the number of classes to be 1 above, so the number of - // output boxes inferred from the result is num_output_boxes_per_class. Value maxOutputBoxesPerClass = rewriter.create( binder.getLoc(), rewriter.getType(), operands[2]); - Value boxesCond = rewriter.create( - binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass); - rewriter.create( - binder.getLoc(), boxesCond, - rewriter.getStringAttr( - "unimplemented: number of output boxes per class should be " - "<= max_output_boxes_per_class")); + Value IOUThreshold = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[3]); + Value scoreThreshold = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[4]); - rewriter.replaceOpWithNewOp(binder.op, resultType, - tensorList, dim); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value two = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + Value three = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3)); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + Value cstTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + Value none = rewriter.create(binder.getLoc()); + + auto boxesType = dyn_cast(boxes.getType()); + auto boxesSizes = + dyn_cast(boxes.getType()).getSizes(); + SmallVector boxesShape(boxesSizes); + int num_batches = boxesShape[0]; + + auto scoresType = dyn_cast(scores.getType()); + auto scoresSizes = + dyn_cast(scores.getType()).getSizes(); + SmallVector scoresShape(scoresSizes); + int num_classes = scoresShape[1]; + + Value selectedIndices = none; + + for (int batch_index = 0; batch_index < num_batches; batch_index++) { + for (int class_index = 0; class_index < num_classes; class_index++) { + + Value selectedBoxesInsideClass = zero; + + Value batchIndexValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(batch_index)); + Value batchIndexPlusOne = rewriter.create( + binder.getLoc(), batchIndexValue, one); + Value classIndexValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(class_index)); + Value classIndexPlusOne = rewriter.create( + binder.getLoc(), classIndexValue, one); + + // Originally, "boxes" is of shape [B,N,4]. Extract the boxes by + // batch_index of shape [1, N, 4], then squeeze to [N,4]. + SmallVector batchBoxesShape{1, boxesShape[1], + boxesShape[2]}; + auto batchBoxesType = rewriter.getType( + batchBoxesShape, boxesType.getOptionalDtype()); + Value batchBoxes = rewriter.create( + binder.getLoc(), batchBoxesType, boxes, /*dim=*/zero, + batchIndexValue, batchIndexPlusOne, one); + batchBoxes = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), + ArrayRef{boxesShape[1], boxesShape[2]}, + boxesType.getOptionalDtype()), + batchBoxes, zero); + + // Originally, "scores" is of shape [B,C,N], Extract the scores of + // batch_index by class_index to get [1,1,N] tensor and squeeze to a + // [N] tensor. + SmallVector classScoresShape{1, scoresShape[1], + scoresShape[2]}; + auto classScoresType = rewriter.getType( + classScoresShape, scoresType.getOptionalDtype()); + Value classScores = rewriter.create( + binder.getLoc(), classScoresType, scores, /*dim=*/zero, + batchIndexValue, batchIndexPlusOne, one); + + classScoresShape = {1, 1, scoresShape[2]}; + classScoresType = rewriter.getType( + classScoresShape, scoresType.getOptionalDtype()); + classScores = rewriter.create( + binder.getLoc(), classScoresType, classScores, /*dim*/ one, + classIndexValue, classIndexPlusOne, one); + classScores = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), + ArrayRef{1, scoresShape[2]}, + scoresType.getOptionalDtype()), + classScores, zero); + classScores = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{scoresShape[2]}, + scoresType.getOptionalDtype()), + classScores, zero); + + // Select each box at least once, unless at or below score threshold + for (int selected_box = 0; selected_box < scoresShape[2]; + selected_box++) { + // Get the index of the next top score in classScores + Value nextTopScoreIndex = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + resultType.getOptionalDtype()), + classScores, /*dim*/ zero, cstTrue); + nextTopScoreIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + nextTopScoreIndex); + + Value nextTopScoreIndexPlusOne = + rewriter.create(binder.getLoc(), + nextTopScoreIndex, one); + // Get the score of the next top score + Value nextTopScore = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + scoresType.getOptionalDtype()), + classScores, /*dim=*/zero, nextTopScoreIndex, + nextTopScoreIndexPlusOne, one); + nextTopScore = rewriter.create( + binder.getLoc(), rewriter.getType(), + nextTopScore); + // Get the box of the next top score + Value nextTopScoreBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), ArrayRef{1, 4}, + batchBoxesType.getOptionalDtype()), + batchBoxes, /*dim=*/zero, nextTopScoreIndex, + nextTopScoreIndexPlusOne, one); + nextTopScoreBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), ArrayRef{4}, + batchBoxesType.getOptionalDtype()), + nextTopScoreBox, zero); + + int numSelectedIndices; + if (isa(selectedIndices.getType())) { + numSelectedIndices = 0; + } else { + auto selectedIndicesSizes = + dyn_cast(selectedIndices.getType()) + .getSizes(); + SmallVector selectedIndicesShape(selectedIndicesSizes); + numSelectedIndices = selectedIndicesShape[0]; + } + Value suppressByIOUCond = cstFalse; + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond); + for (int selectedBoxIndex = 0; + selectedBoxIndex < numSelectedIndices; selectedBoxIndex++) { + Value selectIndex = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(selectedBoxIndex)); + Value selectIndexPlusOne = rewriter.create( + binder.getLoc(), selectIndex, one); + Value selectedBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1, 3}, + resultType.getOptionalDtype()), + selectedIndices, /*dim=*/zero, selectIndex, + selectIndexPlusOne, one); + selectedBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{3}, + resultType.getOptionalDtype()), + selectedBox, zero); + + selectedBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + resultType.getOptionalDtype()), + selectedBox, zero, two); + Value selectedBoxIdx = rewriter.create( + binder.getLoc(), rewriter.getType(), + selectedBox); + Value selectedBoxIdxPlusOne = + rewriter.create(binder.getLoc(), + selectedBoxIdx, one); + selectedBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), ArrayRef{1, 4}, + batchBoxesType.getOptionalDtype()), + batchBoxes, /*dim=*/zero, selectedBoxIdx, + selectedBoxIdxPlusOne, one); + selectedBox = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), ArrayRef{4}, + batchBoxesType.getOptionalDtype()), + selectedBox, zero); + // selectedBox, nextTopScoreBox [4] + if (centerPointBox == 0) { + Value x1_1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + selectedBox, zero, one); + x1_1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + x1_1); + + Value x1_2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + selectedBox, zero, three); + x1_2 = rewriter.create( + binder.getLoc(), rewriter.getType(), + x1_2); + + Value x2_1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + nextTopScoreBox, zero, one); + x2_1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + x2_1); + + Value x2_2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + nextTopScoreBox, zero, three); + x2_2 = rewriter.create( + binder.getLoc(), rewriter.getType(), + x2_2); + Value x1DataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{x1_1, x1_2}); + Value x1CstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value x1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + x1DataList, /*dtype=*/x1CstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value x1_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + x1); + x1_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + x1_min); + Value x1_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + x1); + x1_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + x1_max); + + Value x2DataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{x2_1, x2_2}); + Value x2CstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value x2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + x2DataList, /*dtype=*/x2CstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value x2_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + x2); + x2_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + x2_min); + Value x2_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + x2); + x2_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + x2_max); + + Value xminDataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{x1_min, x2_min}); + Value xminCstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value xmin = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + xminDataList, /*dtype=*/xminCstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value intersection_x_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + xmin); + intersection_x_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + intersection_x_min); + + Value xmaxDataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{x1_max, x2_max}); + Value xmaxCstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value xmax = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + xmaxDataList, /*dtype=*/xmaxCstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value intersection_x_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + xmax); + intersection_x_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + intersection_x_max); + + Value intersectionXCond = + rewriter.create(binder.getLoc(), + intersection_x_min, + intersection_x_max); + intersectionXCond = rewriter.create( + binder.getLoc(), suppressByIOUCond); + intersectionXCond = rewriter.create( + binder.getLoc(), suppressByIOUCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, intersectionXCond); + Value y1_1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + selectedBox, zero, zero); + y1_1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + y1_1); + + Value y1_2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + selectedBox, zero, two); + y1_2 = rewriter.create( + binder.getLoc(), rewriter.getType(), + y1_2); + + Value y2_1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + nextTopScoreBox, zero, zero); + y2_1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + y2_1); + + Value y2_2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + boxesType.getOptionalDtype()), + nextTopScoreBox, zero, two); + y2_2 = rewriter.create( + binder.getLoc(), rewriter.getType(), + y2_2); + Value y1DataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{y1_1, y1_2}); + Value y1CstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value y1 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + y1DataList, /*dtype=*/y1CstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value y1_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + y1); + y1_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + y1_min); + Value y1_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + y1); + y1_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + y1_max); + + Value y2DataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{y2_1, y2_2}); + Value y2CstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value y2 = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + y2DataList, /*dtype=*/y2CstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value y2_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + y2); + y2_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + y2_min); + Value y2_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + y2); + y2_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + y2_max); + Value yminDataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{y1_min, y2_min}); + Value yminCstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value ymin = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + yminDataList, /*dtype=*/yminCstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value intersection_y_min = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + ymin); + intersection_y_min = rewriter.create( + binder.getLoc(), rewriter.getType(), + intersection_y_min); + + Value ymaxDataList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{y1_max, y2_max}); + Value ymaxCstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value ymax = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{2}, + boxesType.getOptionalDtype()), + ymaxDataList, /*dtype=*/ymaxCstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value intersection_y_max = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + boxesType.getOptionalDtype()), + ymax); + intersection_y_max = rewriter.create( + binder.getLoc(), rewriter.getType(), + intersection_y_max); + + Value intersectionYCond = + rewriter.create(binder.getLoc(), + intersection_y_min, + intersection_y_max); + intersectionYCond = rewriter.create( + binder.getLoc(), intersectionYCond); + intersectionYCond = rewriter.create( + binder.getLoc(), intersectionYCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, intersectionYCond); + + Value intersectionYDelta = + rewriter.create( + binder.getLoc(), intersection_y_max, + intersection_y_min); + Value intersectionXDelta = + rewriter.create( + binder.getLoc(), intersection_x_max, + intersection_x_min); + Value intersectionArea = + rewriter.create( + binder.getLoc(), intersectionYDelta, + intersectionXDelta); + Value intersectionAreaCond = + rewriter.create( + binder.getLoc(), intersectionArea, zero); + intersectionAreaCond = rewriter.create( + binder.getLoc(), intersectionAreaCond); + + Value x1Delta = rewriter.create( + binder.getLoc(), x1_max, x1_min); + Value y1Delta = rewriter.create( + binder.getLoc(), y1_max, y1_min); + Value area1 = rewriter.create( + binder.getLoc(), x1Delta, y1Delta); + + Value x2Delta = rewriter.create( + binder.getLoc(), x2_max, x2_min); + Value y2Delta = rewriter.create( + binder.getLoc(), y2_max, y2_min); + Value area2 = rewriter.create( + binder.getLoc(), x2Delta, y2Delta); + Value unionArea = rewriter.create( + binder.getLoc(), rewriter.getType(), + area1, area2); + unionArea = rewriter.create( + binder.getLoc(), unionArea, intersectionArea); + Value area1Cond = rewriter.create( + binder.getLoc(), area1, zero); + area1Cond = rewriter.create( + binder.getLoc(), area1Cond); + Value area2Cond = rewriter.create( + binder.getLoc(), area2, zero); + area2Cond = rewriter.create( + binder.getLoc(), area2Cond); + Value unionAreaCond = rewriter.create( + binder.getLoc(), unionArea, zero); + unionAreaCond = rewriter.create( + binder.getLoc(), unionAreaCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, area1Cond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, area2Cond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, unionAreaCond); + + Value intersectionOverUnion = + rewriter.create( + binder.getLoc(), intersectionArea, unionArea); + Value IOUCond = rewriter.create( + binder.getLoc(), intersectionOverUnion, IOUThreshold); + IOUCond = rewriter.create( + binder.getLoc(), IOUCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, IOUCond); + } else { + // TODO: centerPointBox != 0 + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: suppress by iou, centerPointBox != 0"); + } + } + + Value dataList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{batchIndexValue, classIndexValue, + nextTopScoreIndex}); + Value cstDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Int)); + Value selectedIndex = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{3}, + resultType.getOptionalDtype()), + dataList, /*dtype=*/cstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + selectedIndex = rewriter.create( + binder.getLoc(), + rewriter.getType( + llvm::SmallVector{1, 3}, resultType.getDtype()), + selectedIndex, zero); + + Value numOutputBoxesCond = rewriter.create( + binder.getLoc(), selectedBoxesInsideClass, + maxOutputBoxesPerClass); + numOutputBoxesCond = rewriter.create( + binder.getLoc(), numOutputBoxesCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, numOutputBoxesCond); + + Value scoresCond = rewriter.create( + binder.getLoc(), scoreThreshold, nextTopScore); + scoresCond = rewriter.create( + binder.getLoc(), scoresCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond, scoresCond); + suppressByIOUCond = rewriter.create( + binder.getLoc(), suppressByIOUCond); + + // if suppressByIOUCond is true do not select the box + Value conditionBool = suppressByIOUCond; + llvm::SmallVector ifTypes; + Type elseType; + + if (isa(selectedIndices.getType())) { + elseType = rewriter.getType( + llvm::SmallVector{1, 3}, resultType.getDtype()); + } else { + auto selectedIndicesSizes = + dyn_cast(selectedIndices.getType()) + .getSizes(); + SmallVector selectedIndicesShape(selectedIndicesSizes); + elseType = rewriter.getType( + llvm::SmallVector{selectedIndicesShape[0] + 1, 3}, + resultType.getDtype()); + } + // ifTypes.push_back(selectedIndices.getType()); + ifTypes.push_back(elseType); + auto primIf = rewriter.create( + binder.getLoc(), elseType, conditionBool); + { + Region &elseRegion = primIf.getElseRegion(); + rewriter.createBlock(&elseRegion, elseRegion.end()); + + if (isa(selectedIndices.getType())) { + selectedIndices = selectedIndex; + } else { + Type listElemType = + cast(resultType) + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = + rewriter.create( + binder.getLoc(), listType, + SmallVector{selectedIndices, selectedIndex}); + auto selectedIndicesSizes = dyn_cast( + selectedIndices.getType()) + .getSizes(); + SmallVector selectedIndicesShape( + selectedIndicesSizes); + selectedIndices = rewriter.create( + binder.getLoc(), + rewriter.getType( + llvm::SmallVector{ + selectedIndicesShape[0] + 1, 3}, + resultType.getDtype()), + tensorList, /*dim=*/zero); + } + selectedBoxesInsideClass = rewriter.create( + binder.getLoc(), selectedBoxesInsideClass, one); + rewriter.create( + binder.getLoc(), ValueRange{selectedIndices}); + } + { + Region &thenRegion = primIf.getThenRegion(); + rewriter.createBlock(&thenRegion, thenRegion.end()); + rewriter.create( + binder.getLoc(), ValueRange{selectedIndices}); + } + rewriter.setInsertionPointAfter(primIf); + + // set to scoreThreshold at maxScoreIndex in classScores so it + // does not select the same scoreIndex before other boxes with + // greater scores + Value replace = scoreThreshold; + dataList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{replace}); + cstDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + Value replaceTensor = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{1}, + scoresType.getOptionalDtype()), + dataList, /*dtype=*/cstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + Value insertEnd = rewriter.create( + binder.getLoc(), nextTopScoreIndex, one); + classScores = rewriter.create( + binder.getLoc(), classScores.getType(), classScores, + replaceTensor, + /*dim=*/zero, nextTopScoreIndex, insertEnd, /*step=*/one); + } + } + } + rewriter.replaceOp(binder.op, selectedIndices); return success(); }); }