onnxtotorch NonMaxSuppression lowering

pull/3677/head
Alex 2024-08-26 17:24:36 +00:00
parent 0a86deb59a
commit 378454ed85
1 changed files with 739 additions and 112 deletions

View File

@ -3552,123 +3552,750 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected all 5 args to be present");
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");
FailureOr<Value> squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
boxes = squeezedBoxes.value();
scores = squeezedScores.value();
// TODO: Add support for handling score_threshold arg.
// If score_threshold > min(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), resultType, boxes, scores, iouThreshold);
// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();
Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), result,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);
std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
return rewriter.notifyMatchFailure(
binder.op, "expected result tensor to have shape");
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value zeros = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
cstNone);
Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});
// TODO: Add support for handling max_output_boxes_per_class arg.
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
// be lowered since the torchvision::nms op doesn't have support for
// handling the max_output_boxes_per_class arg. Also, we have already
// constrained the number of classes to be 1 above, so the number of
// output boxes inferred from the result is num_output_boxes_per_class.
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), boxesCond,
rewriter.getStringAttr(
"unimplemented: number of output boxes per class should be "
"<= max_output_boxes_per_class"));
Value IOUThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, dim);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value two = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
Value three = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(3));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
auto boxesType = dyn_cast<Torch::ValueTensorType>(boxes.getType());
auto boxesSizes =
dyn_cast<Torch::ValueTensorType>(boxes.getType()).getSizes();
SmallVector<int64_t> boxesShape(boxesSizes);
int num_batches = boxesShape[0];
auto scoresType = dyn_cast<Torch::ValueTensorType>(scores.getType());
auto scoresSizes =
dyn_cast<Torch::ValueTensorType>(scores.getType()).getSizes();
SmallVector<int64_t> scoresShape(scoresSizes);
int num_classes = scoresShape[1];
Value selectedIndices = none;
for (int batch_index = 0; batch_index < num_batches; batch_index++) {
for (int class_index = 0; class_index < num_classes; class_index++) {
Value selectedBoxesInsideClass = zero;
Value batchIndexValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(batch_index));
Value batchIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), batchIndexValue, one);
Value classIndexValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(class_index));
Value classIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), classIndexValue, one);
// Originally, "boxes" is of shape [B,N,4]. Extract the boxes by
// batch_index of shape [1, N, 4], then squeeze to [N,4].
SmallVector<int64_t> batchBoxesShape{1, boxesShape[1],
boxesShape[2]};
auto batchBoxesType = rewriter.getType<Torch::ValueTensorType>(
batchBoxesShape, boxesType.getOptionalDtype());
Value batchBoxes = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), batchBoxesType, boxes, /*dim=*/zero,
batchIndexValue, batchIndexPlusOne, one);
batchBoxes = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(),
ArrayRef<int64_t>{boxesShape[1], boxesShape[2]},
boxesType.getOptionalDtype()),
batchBoxes, zero);
// Originally, "scores" is of shape [B,C,N], Extract the scores of
// batch_index by class_index to get [1,1,N] tensor and squeeze to a
// [N] tensor.
SmallVector<int64_t> classScoresShape{1, scoresShape[1],
scoresShape[2]};
auto classScoresType = rewriter.getType<Torch::ValueTensorType>(
classScoresShape, scoresType.getOptionalDtype());
Value classScores = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), classScoresType, scores, /*dim=*/zero,
batchIndexValue, batchIndexPlusOne, one);
classScoresShape = {1, 1, scoresShape[2]};
classScoresType = rewriter.getType<Torch::ValueTensorType>(
classScoresShape, scoresType.getOptionalDtype());
classScores = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), classScoresType, classScores, /*dim*/ one,
classIndexValue, classIndexPlusOne, one);
classScores = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(),
ArrayRef<int64_t>{1, scoresShape[2]},
scoresType.getOptionalDtype()),
classScores, zero);
classScores = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{scoresShape[2]},
scoresType.getOptionalDtype()),
classScores, zero);
// Select each box at least once, unless at or below score threshold
for (int selected_box = 0; selected_box < scoresShape[2];
selected_box++) {
// Get the index of the next top score in classScores
Value nextTopScoreIndex = rewriter.create<Torch::AtenArgmaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
resultType.getOptionalDtype()),
classScores, /*dim*/ zero, cstTrue);
nextTopScoreIndex = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
nextTopScoreIndex);
Value nextTopScoreIndexPlusOne =
rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(),
nextTopScoreIndex, one);
// Get the score of the next top score
Value nextTopScore = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
scoresType.getOptionalDtype()),
classScores, /*dim=*/zero, nextTopScoreIndex,
nextTopScoreIndexPlusOne, one);
nextTopScore = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
nextTopScore);
// Get the box of the next top score
Value nextTopScoreBox = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(), ArrayRef<int64_t>{1, 4},
batchBoxesType.getOptionalDtype()),
batchBoxes, /*dim=*/zero, nextTopScoreIndex,
nextTopScoreIndexPlusOne, one);
nextTopScoreBox = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(), ArrayRef<int64_t>{4},
batchBoxesType.getOptionalDtype()),
nextTopScoreBox, zero);
int numSelectedIndices;
if (isa<Torch::NoneType>(selectedIndices.getType())) {
numSelectedIndices = 0;
} else {
auto selectedIndicesSizes =
dyn_cast<Torch::ValueTensorType>(selectedIndices.getType())
.getSizes();
SmallVector<int64_t> selectedIndicesShape(selectedIndicesSizes);
numSelectedIndices = selectedIndicesShape[0];
}
Value suppressByIOUCond = cstFalse;
suppressByIOUCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), suppressByIOUCond);
for (int selectedBoxIndex = 0;
selectedBoxIndex < numSelectedIndices; selectedBoxIndex++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(selectedBoxIndex));
Value selectIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), selectIndex, one);
Value selectedBox = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1, 3},
resultType.getOptionalDtype()),
selectedIndices, /*dim=*/zero, selectIndex,
selectIndexPlusOne, one);
selectedBox = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{3},
resultType.getOptionalDtype()),
selectedBox, zero);
selectedBox = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
resultType.getOptionalDtype()),
selectedBox, zero, two);
Value selectedBoxIdx = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
selectedBox);
Value selectedBoxIdxPlusOne =
rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(),
selectedBoxIdx, one);
selectedBox = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(), ArrayRef<int64_t>{1, 4},
batchBoxesType.getOptionalDtype()),
batchBoxes, /*dim=*/zero, selectedBoxIdx,
selectedBoxIdxPlusOne, one);
selectedBox = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(), ArrayRef<int64_t>{4},
batchBoxesType.getOptionalDtype()),
selectedBox, zero);
// selectedBox, nextTopScoreBox [4]
if (centerPointBox == 0) {
Value x1_1 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
selectedBox, zero, one);
x1_1 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x1_1);
Value x1_2 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
selectedBox, zero, three);
x1_2 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x1_2);
Value x2_1 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
nextTopScoreBox, zero, one);
x2_1 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x2_1);
Value x2_2 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
nextTopScoreBox, zero, three);
x2_2 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x2_2);
Value x1DataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{x1_1, x1_2});
Value x1CstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value x1 = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
x1DataList, /*dtype=*/x1CstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value x1_min = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
x1);
x1_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x1_min);
Value x1_max = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
x1);
x1_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x1_max);
Value x2DataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{x2_1, x2_2});
Value x2CstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value x2 = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
x2DataList, /*dtype=*/x2CstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value x2_min = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
x2);
x2_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x2_min);
Value x2_max = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
x2);
x2_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
x2_max);
Value xminDataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{x1_min, x2_min});
Value xminCstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value xmin = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
xminDataList, /*dtype=*/xminCstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value intersection_x_min = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
xmin);
intersection_x_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
intersection_x_min);
Value xmaxDataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{x1_max, x2_max});
Value xmaxCstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value xmax = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
xmaxDataList, /*dtype=*/xmaxCstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value intersection_x_max = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
xmax);
intersection_x_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
intersection_x_max);
Value intersectionXCond =
rewriter.create<Torch::AtenGtFloatOp>(binder.getLoc(),
intersection_x_min,
intersection_x_max);
intersectionXCond = rewriter.create<Torch::Aten__Not__Op>(
binder.getLoc(), suppressByIOUCond);
intersectionXCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), suppressByIOUCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, intersectionXCond);
Value y1_1 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
selectedBox, zero, zero);
y1_1 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y1_1);
Value y1_2 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
selectedBox, zero, two);
y1_2 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y1_2);
Value y2_1 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
nextTopScoreBox, zero, zero);
y2_1 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y2_1);
Value y2_2 = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
boxesType.getOptionalDtype()),
nextTopScoreBox, zero, two);
y2_2 = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y2_2);
Value y1DataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{y1_1, y1_2});
Value y1CstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value y1 = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
y1DataList, /*dtype=*/y1CstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value y1_min = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
y1);
y1_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y1_min);
Value y1_max = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
y1);
y1_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y1_max);
Value y2DataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{y2_1, y2_2});
Value y2CstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value y2 = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
y2DataList, /*dtype=*/y2CstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value y2_min = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
y2);
y2_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y2_min);
Value y2_max = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
y2);
y2_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
y2_max);
Value yminDataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{y1_min, y2_min});
Value yminCstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value ymin = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
yminDataList, /*dtype=*/yminCstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value intersection_y_min = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
ymin);
intersection_y_min = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
intersection_y_min);
Value ymaxDataList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{y1_max, y2_max});
Value ymaxCstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value ymax = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{2},
boxesType.getOptionalDtype()),
ymaxDataList, /*dtype=*/ymaxCstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value intersection_y_max = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
boxesType.getOptionalDtype()),
ymax);
intersection_y_max = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
intersection_y_max);
Value intersectionYCond =
rewriter.create<Torch::AtenGtFloatOp>(binder.getLoc(),
intersection_y_min,
intersection_y_max);
intersectionYCond = rewriter.create<Torch::Aten__Not__Op>(
binder.getLoc(), intersectionYCond);
intersectionYCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), intersectionYCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, intersectionYCond);
Value intersectionYDelta =
rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), intersection_y_max,
intersection_y_min);
Value intersectionXDelta =
rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), intersection_x_max,
intersection_x_min);
Value intersectionArea =
rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), intersectionYDelta,
intersectionXDelta);
Value intersectionAreaCond =
rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), intersectionArea, zero);
intersectionAreaCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), intersectionAreaCond);
Value x1Delta = rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), x1_max, x1_min);
Value y1Delta = rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), y1_max, y1_min);
Value area1 = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), x1Delta, y1Delta);
Value x2Delta = rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), x2_max, x2_min);
Value y2Delta = rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), y2_max, y2_min);
Value area2 = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), x2Delta, y2Delta);
Value unionArea = rewriter.create<Torch::AtenAddOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
area1, area2);
unionArea = rewriter.create<Torch::AtenSubFloatOp>(
binder.getLoc(), unionArea, intersectionArea);
Value area1Cond = rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), area1, zero);
area1Cond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), area1Cond);
Value area2Cond = rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), area2, zero);
area2Cond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), area2Cond);
Value unionAreaCond = rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), unionArea, zero);
unionAreaCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), unionAreaCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, area1Cond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, area2Cond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, unionAreaCond);
Value intersectionOverUnion =
rewriter.create<Torch::AtenDivFloatOp>(
binder.getLoc(), intersectionArea, unionArea);
Value IOUCond = rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), intersectionOverUnion, IOUThreshold);
IOUCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), IOUCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, IOUCond);
} else {
// TODO: centerPointBox != 0
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: suppress by iou, centerPointBox != 0");
}
}
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{batchIndexValue, classIndexValue,
nextTopScoreIndex});
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Int));
Value selectedIndex = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{3},
resultType.getOptionalDtype()),
dataList, /*dtype=*/cstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
selectedIndex = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(),
rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1, 3}, resultType.getDtype()),
selectedIndex, zero);
Value numOutputBoxesCond = rewriter.create<Torch::AtenGtIntOp>(
binder.getLoc(), selectedBoxesInsideClass,
maxOutputBoxesPerClass);
numOutputBoxesCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), numOutputBoxesCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, numOutputBoxesCond);
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), scoreThreshold, nextTopScore);
scoresCond = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), scoresCond);
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), suppressByIOUCond, scoresCond);
suppressByIOUCond = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), suppressByIOUCond);
// if suppressByIOUCond is true do not select the box
Value conditionBool = suppressByIOUCond;
llvm::SmallVector<mlir::Type> ifTypes;
Type elseType;
if (isa<Torch::NoneType>(selectedIndices.getType())) {
elseType = rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1, 3}, resultType.getDtype());
} else {
auto selectedIndicesSizes =
dyn_cast<Torch::ValueTensorType>(selectedIndices.getType())
.getSizes();
SmallVector<int64_t> selectedIndicesShape(selectedIndicesSizes);
elseType = rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{selectedIndicesShape[0] + 1, 3},
resultType.getDtype());
}
// ifTypes.push_back(selectedIndices.getType());
ifTypes.push_back(elseType);
auto primIf = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), elseType, conditionBool);
{
Region &elseRegion = primIf.getElseRegion();
rewriter.createBlock(&elseRegion, elseRegion.end());
if (isa<Torch::NoneType>(selectedIndices.getType())) {
selectedIndices = selectedIndex;
} else {
Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), listType,
SmallVector<Value>{selectedIndices, selectedIndex});
auto selectedIndicesSizes = dyn_cast<Torch::ValueTensorType>(
selectedIndices.getType())
.getSizes();
SmallVector<int64_t> selectedIndicesShape(
selectedIndicesSizes);
selectedIndices = rewriter.create<Torch::AtenCatOp>(
binder.getLoc(),
rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{
selectedIndicesShape[0] + 1, 3},
resultType.getDtype()),
tensorList, /*dim=*/zero);
}
selectedBoxesInsideClass = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), selectedBoxesInsideClass, one);
rewriter.create<Torch::PrimIfYieldOp>(
binder.getLoc(), ValueRange{selectedIndices});
}
{
Region &thenRegion = primIf.getThenRegion();
rewriter.createBlock(&thenRegion, thenRegion.end());
rewriter.create<Torch::PrimIfYieldOp>(
binder.getLoc(), ValueRange{selectedIndices});
}
rewriter.setInsertionPointAfter(primIf);
// set to scoreThreshold at maxScoreIndex in classScores so it
// does not select the same scoreIndex before other boxes with
// greater scores
Value replace = scoreThreshold;
dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{replace});
cstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
Value replaceTensor = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{1},
scoresType.getOptionalDtype()),
dataList, /*dtype=*/cstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), nextTopScoreIndex, one);
classScores = rewriter.create<Torch::AtenSliceScatterOp>(
binder.getLoc(), classScores.getType(), classScores,
replaceTensor,
/*dim=*/zero, nextTopScoreIndex, insertEnd, /*step=*/one);
}
}
}
rewriter.replaceOp(binder.op, selectedIndices);
return success();
});
}