mirror of https://github.com/llvm/torch-mlir
onnxtotorch NonMaxSuppression lowering
parent
0a86deb59a
commit
378454ed85
|
@ -3552,123 +3552,750 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected all 5 args to be present");
|
||||
|
||||
// Squeeze the boxes and scores tensor.
|
||||
// In Onnx, the shape of boxes is [BxNx4] while the
|
||||
// torchvision expects it to be of shape [Nx4]. Similarly, for
|
||||
// the scores tensor shape in Onnx is [BxCxN] while the
|
||||
// torchvision expects it to be of shape [N].
|
||||
Value boxes = operands[0], scores = operands[1];
|
||||
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, boxes);
|
||||
if (failed(squeezedBoxes))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze boxes tensor");
|
||||
|
||||
FailureOr<Value> squeezedScores = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, scores);
|
||||
if (failed(squeezedScores))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze scores tensor");
|
||||
squeezedScores = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
|
||||
if (failed(squeezedScores))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze scores tensor");
|
||||
|
||||
boxes = squeezedBoxes.value();
|
||||
scores = squeezedScores.value();
|
||||
|
||||
// TODO: Add support for handling score_threshold arg.
|
||||
// If score_threshold > min(scores) then the op can't be lowered since
|
||||
// the torchvision::nms op doesn't have support for handling the
|
||||
// score_threshold arg.
|
||||
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
|
||||
Value IOUThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
|
||||
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
|
||||
Value minScores = rewriter.create<Torch::AtenMinOp>(
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value two = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(2));
|
||||
Value three = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(3));
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(false));
|
||||
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(true));
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
|
||||
auto boxesType = dyn_cast<Torch::ValueTensorType>(boxes.getType());
|
||||
auto boxesSizes =
|
||||
dyn_cast<Torch::ValueTensorType>(boxes.getType()).getSizes();
|
||||
SmallVector<int64_t> boxesShape(boxesSizes);
|
||||
int num_batches = boxesShape[0];
|
||||
|
||||
auto scoresType = dyn_cast<Torch::ValueTensorType>(scores.getType());
|
||||
auto scoresSizes =
|
||||
dyn_cast<Torch::ValueTensorType>(scores.getType()).getSizes();
|
||||
SmallVector<int64_t> scoresShape(scoresSizes);
|
||||
int num_classes = scoresShape[1];
|
||||
|
||||
Value selectedIndices = none;
|
||||
|
||||
for (int batch_index = 0; batch_index < num_batches; batch_index++) {
|
||||
for (int class_index = 0; class_index < num_classes; class_index++) {
|
||||
|
||||
Value selectedBoxesInsideClass = zero;
|
||||
|
||||
Value batchIndexValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(batch_index));
|
||||
Value batchIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), batchIndexValue, one);
|
||||
Value classIndexValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(class_index));
|
||||
Value classIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), classIndexValue, one);
|
||||
|
||||
// Originally, "boxes" is of shape [B,N,4]. Extract the boxes by
|
||||
// batch_index of shape [1, N, 4], then squeeze to [N,4].
|
||||
SmallVector<int64_t> batchBoxesShape{1, boxesShape[1],
|
||||
boxesShape[2]};
|
||||
auto batchBoxesType = rewriter.getType<Torch::ValueTensorType>(
|
||||
batchBoxesShape, boxesType.getOptionalDtype());
|
||||
Value batchBoxes = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), batchBoxesType, boxes, /*dim=*/zero,
|
||||
batchIndexValue, batchIndexPlusOne, one);
|
||||
batchBoxes = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(),
|
||||
ArrayRef<int64_t>{boxesShape[1], boxesShape[2]},
|
||||
boxesType.getOptionalDtype()),
|
||||
batchBoxes, zero);
|
||||
|
||||
// Originally, "scores" is of shape [B,C,N], Extract the scores of
|
||||
// batch_index by class_index to get [1,1,N] tensor and squeeze to a
|
||||
// [N] tensor.
|
||||
SmallVector<int64_t> classScoresShape{1, scoresShape[1],
|
||||
scoresShape[2]};
|
||||
auto classScoresType = rewriter.getType<Torch::ValueTensorType>(
|
||||
classScoresShape, scoresType.getOptionalDtype());
|
||||
Value classScores = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), classScoresType, scores, /*dim=*/zero,
|
||||
batchIndexValue, batchIndexPlusOne, one);
|
||||
|
||||
classScoresShape = {1, 1, scoresShape[2]};
|
||||
classScoresType = rewriter.getType<Torch::ValueTensorType>(
|
||||
classScoresShape, scoresType.getOptionalDtype());
|
||||
classScores = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), classScoresType, classScores, /*dim*/ one,
|
||||
classIndexValue, classIndexPlusOne, one);
|
||||
classScores = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1, scoresShape[2]},
|
||||
scoresType.getOptionalDtype()),
|
||||
classScores, zero);
|
||||
classScores = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{scoresShape[2]},
|
||||
scoresType.getOptionalDtype()),
|
||||
classScores, zero);
|
||||
|
||||
// Select each box at least once, unless at or below score threshold
|
||||
for (int selected_box = 0; selected_box < scoresShape[2];
|
||||
selected_box++) {
|
||||
// Get the index of the next top score in classScores
|
||||
Value nextTopScoreIndex = rewriter.create<Torch::AtenArgmaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
resultType.getOptionalDtype()),
|
||||
classScores, /*dim*/ zero, cstTrue);
|
||||
nextTopScoreIndex = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
nextTopScoreIndex);
|
||||
|
||||
Value nextTopScoreIndexPlusOne =
|
||||
rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(),
|
||||
nextTopScoreIndex, one);
|
||||
// Get the score of the next top score
|
||||
Value nextTopScore = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
scoresType.getOptionalDtype()),
|
||||
classScores, /*dim=*/zero, nextTopScoreIndex,
|
||||
nextTopScoreIndexPlusOne, one);
|
||||
nextTopScore = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
nextTopScore);
|
||||
// Get the box of the next top score
|
||||
Value nextTopScoreBox = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), ArrayRef<int64_t>{1, 4},
|
||||
batchBoxesType.getOptionalDtype()),
|
||||
batchBoxes, /*dim=*/zero, nextTopScoreIndex,
|
||||
nextTopScoreIndexPlusOne, one);
|
||||
nextTopScoreBox = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), ArrayRef<int64_t>{4},
|
||||
batchBoxesType.getOptionalDtype()),
|
||||
nextTopScoreBox, zero);
|
||||
|
||||
int numSelectedIndices;
|
||||
if (isa<Torch::NoneType>(selectedIndices.getType())) {
|
||||
numSelectedIndices = 0;
|
||||
} else {
|
||||
auto selectedIndicesSizes =
|
||||
dyn_cast<Torch::ValueTensorType>(selectedIndices.getType())
|
||||
.getSizes();
|
||||
SmallVector<int64_t> selectedIndicesShape(selectedIndicesSizes);
|
||||
numSelectedIndices = selectedIndicesShape[0];
|
||||
}
|
||||
Value suppressByIOUCond = cstFalse;
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), suppressByIOUCond);
|
||||
for (int selectedBoxIndex = 0;
|
||||
selectedBoxIndex < numSelectedIndices; selectedBoxIndex++) {
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(selectedBoxIndex));
|
||||
Value selectIndexPlusOne = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), selectIndex, one);
|
||||
Value selectedBox = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1, 3},
|
||||
resultType.getOptionalDtype()),
|
||||
selectedIndices, /*dim=*/zero, selectIndex,
|
||||
selectIndexPlusOne, one);
|
||||
selectedBox = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{3},
|
||||
resultType.getOptionalDtype()),
|
||||
selectedBox, zero);
|
||||
|
||||
selectedBox = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
resultType.getOptionalDtype()),
|
||||
selectedBox, zero, two);
|
||||
Value selectedBoxIdx = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
selectedBox);
|
||||
Value selectedBoxIdxPlusOne =
|
||||
rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(),
|
||||
selectedBoxIdx, one);
|
||||
selectedBox = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), ArrayRef<int64_t>{1, 4},
|
||||
batchBoxesType.getOptionalDtype()),
|
||||
batchBoxes, /*dim=*/zero, selectedBoxIdx,
|
||||
selectedBoxIdxPlusOne, one);
|
||||
selectedBox = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), ArrayRef<int64_t>{4},
|
||||
batchBoxesType.getOptionalDtype()),
|
||||
selectedBox, zero);
|
||||
// selectedBox, nextTopScoreBox [4]
|
||||
if (centerPointBox == 0) {
|
||||
Value x1_1 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
selectedBox, zero, one);
|
||||
x1_1 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x1_1);
|
||||
|
||||
Value x1_2 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
selectedBox, zero, three);
|
||||
x1_2 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x1_2);
|
||||
|
||||
Value x2_1 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
nextTopScoreBox, zero, one);
|
||||
x2_1 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x2_1);
|
||||
|
||||
Value x2_2 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
nextTopScoreBox, zero, three);
|
||||
x2_2 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x2_2);
|
||||
Value x1DataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{x1_1, x1_2});
|
||||
Value x1CstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value x1 = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
x1DataList, /*dtype=*/x1CstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value x1_min = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
rewriter.getF32Type()),
|
||||
scores);
|
||||
minScores = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
|
||||
boxesType.getOptionalDtype()),
|
||||
x1);
|
||||
x1_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x1_min);
|
||||
Value x1_max = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
x1);
|
||||
x1_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x1_max);
|
||||
|
||||
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
|
||||
binder.getLoc(), minScores, scoreThreshold);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
binder.getLoc(), scoresCond,
|
||||
rewriter.getStringAttr(
|
||||
"unimplemented: score_threshold should be <= min(scores)"));
|
||||
Value x2DataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{x2_1, x2_2});
|
||||
Value x2CstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value x2 = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
x2DataList, /*dtype=*/x2CstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value x2_min = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
x2);
|
||||
x2_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x2_min);
|
||||
Value x2_max = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
x2);
|
||||
x2_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
x2_max);
|
||||
|
||||
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
|
||||
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
|
||||
binder.getLoc(), resultType, boxes, scores, iouThreshold);
|
||||
Value xminDataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{x1_min, x2_min});
|
||||
Value xminCstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value xmin = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
xminDataList, /*dtype=*/xminCstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value intersection_x_min = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
xmin);
|
||||
intersection_x_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
intersection_x_min);
|
||||
|
||||
// The result generated by torchvision.nms op is of shape [n], while the
|
||||
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
|
||||
// and make it of shape [n, 1] and then concatenate it with a zero
|
||||
// tensor of shape [n, 2] to make it of shape [n, 3].
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
FailureOr<Value> unsqueezedResult =
|
||||
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
|
||||
if (failed(unsqueezedResult))
|
||||
Value xmaxDataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{x1_max, x2_max});
|
||||
Value xmaxCstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value xmax = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
xmaxDataList, /*dtype=*/xmaxCstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value intersection_x_max = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
xmax);
|
||||
intersection_x_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
intersection_x_max);
|
||||
|
||||
Value intersectionXCond =
|
||||
rewriter.create<Torch::AtenGtFloatOp>(binder.getLoc(),
|
||||
intersection_x_min,
|
||||
intersection_x_max);
|
||||
intersectionXCond = rewriter.create<Torch::Aten__Not__Op>(
|
||||
binder.getLoc(), suppressByIOUCond);
|
||||
intersectionXCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), suppressByIOUCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, intersectionXCond);
|
||||
Value y1_1 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
selectedBox, zero, zero);
|
||||
y1_1 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y1_1);
|
||||
|
||||
Value y1_2 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
selectedBox, zero, two);
|
||||
y1_2 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y1_2);
|
||||
|
||||
Value y2_1 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
nextTopScoreBox, zero, zero);
|
||||
y2_1 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y2_1);
|
||||
|
||||
Value y2_2 = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
boxesType.getOptionalDtype()),
|
||||
nextTopScoreBox, zero, two);
|
||||
y2_2 = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y2_2);
|
||||
Value y1DataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{y1_1, y1_2});
|
||||
Value y1CstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value y1 = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
y1DataList, /*dtype=*/y1CstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value y1_min = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
y1);
|
||||
y1_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y1_min);
|
||||
Value y1_max = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
y1);
|
||||
y1_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y1_max);
|
||||
|
||||
Value y2DataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{y2_1, y2_2});
|
||||
Value y2CstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value y2 = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
y2DataList, /*dtype=*/y2CstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value y2_min = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
y2);
|
||||
y2_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y2_min);
|
||||
Value y2_max = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
y2);
|
||||
y2_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
y2_max);
|
||||
Value yminDataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{y1_min, y2_min});
|
||||
Value yminCstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value ymin = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
yminDataList, /*dtype=*/yminCstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value intersection_y_min = rewriter.create<Torch::AtenMaxOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
ymin);
|
||||
intersection_y_min = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
intersection_y_min);
|
||||
|
||||
Value ymaxDataList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{y1_max, y2_max});
|
||||
Value ymaxCstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value ymax = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{2},
|
||||
boxesType.getOptionalDtype()),
|
||||
ymaxDataList, /*dtype=*/ymaxCstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value intersection_y_max = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
boxesType.getOptionalDtype()),
|
||||
ymax);
|
||||
intersection_y_max = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
intersection_y_max);
|
||||
|
||||
Value intersectionYCond =
|
||||
rewriter.create<Torch::AtenGtFloatOp>(binder.getLoc(),
|
||||
intersection_y_min,
|
||||
intersection_y_max);
|
||||
intersectionYCond = rewriter.create<Torch::Aten__Not__Op>(
|
||||
binder.getLoc(), intersectionYCond);
|
||||
intersectionYCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), intersectionYCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, intersectionYCond);
|
||||
|
||||
Value intersectionYDelta =
|
||||
rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), intersection_y_max,
|
||||
intersection_y_min);
|
||||
Value intersectionXDelta =
|
||||
rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), intersection_x_max,
|
||||
intersection_x_min);
|
||||
Value intersectionArea =
|
||||
rewriter.create<Torch::AtenMulFloatOp>(
|
||||
binder.getLoc(), intersectionYDelta,
|
||||
intersectionXDelta);
|
||||
Value intersectionAreaCond =
|
||||
rewriter.create<Torch::AtenGtFloatOp>(
|
||||
binder.getLoc(), intersectionArea, zero);
|
||||
intersectionAreaCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), intersectionAreaCond);
|
||||
|
||||
Value x1Delta = rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), x1_max, x1_min);
|
||||
Value y1Delta = rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), y1_max, y1_min);
|
||||
Value area1 = rewriter.create<Torch::AtenMulFloatOp>(
|
||||
binder.getLoc(), x1Delta, y1Delta);
|
||||
|
||||
Value x2Delta = rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), x2_max, x2_min);
|
||||
Value y2Delta = rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), y2_max, y2_min);
|
||||
Value area2 = rewriter.create<Torch::AtenMulFloatOp>(
|
||||
binder.getLoc(), x2Delta, y2Delta);
|
||||
Value unionArea = rewriter.create<Torch::AtenAddOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
area1, area2);
|
||||
unionArea = rewriter.create<Torch::AtenSubFloatOp>(
|
||||
binder.getLoc(), unionArea, intersectionArea);
|
||||
Value area1Cond = rewriter.create<Torch::AtenGtFloatOp>(
|
||||
binder.getLoc(), area1, zero);
|
||||
area1Cond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), area1Cond);
|
||||
Value area2Cond = rewriter.create<Torch::AtenGtFloatOp>(
|
||||
binder.getLoc(), area2, zero);
|
||||
area2Cond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), area2Cond);
|
||||
Value unionAreaCond = rewriter.create<Torch::AtenGtFloatOp>(
|
||||
binder.getLoc(), unionArea, zero);
|
||||
unionAreaCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), unionAreaCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, area1Cond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, area2Cond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, unionAreaCond);
|
||||
|
||||
Value intersectionOverUnion =
|
||||
rewriter.create<Torch::AtenDivFloatOp>(
|
||||
binder.getLoc(), intersectionArea, unionArea);
|
||||
Value IOUCond = rewriter.create<Torch::AtenGtFloatOp>(
|
||||
binder.getLoc(), intersectionOverUnion, IOUThreshold);
|
||||
IOUCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), IOUCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, IOUCond);
|
||||
} else {
|
||||
// TODO: centerPointBox != 0
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "failed to unsqueeze result tensor");
|
||||
result = unsqueezedResult.value();
|
||||
binder.op,
|
||||
"unimplemented: suppress by iou, centerPointBox != 0");
|
||||
}
|
||||
}
|
||||
|
||||
Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), result,
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
||||
SmallVector<Value> zerosShapeValues{numOutputBoxes};
|
||||
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
|
||||
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
zerosShapeValues);
|
||||
SmallVector<Value>{batchIndexValue, classIndexValue,
|
||||
nextTopScoreIndex});
|
||||
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Int));
|
||||
Value selectedIndex = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{3},
|
||||
resultType.getOptionalDtype()),
|
||||
dataList, /*dtype=*/cstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
selectedIndex = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1, 3}, resultType.getDtype()),
|
||||
selectedIndex, zero);
|
||||
|
||||
std::optional<ArrayRef<int64_t>> resultShape =
|
||||
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
|
||||
if (!resultShape.has_value())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "expected result tensor to have shape");
|
||||
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
|
||||
auto zerosTy = Torch::ValueTensorType::get(
|
||||
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value zeros = rewriter.create<Torch::AtenZerosOp>(
|
||||
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
|
||||
cstNone);
|
||||
Value numOutputBoxesCond = rewriter.create<Torch::AtenGtIntOp>(
|
||||
binder.getLoc(), selectedBoxesInsideClass,
|
||||
maxOutputBoxesPerClass);
|
||||
numOutputBoxesCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), numOutputBoxesCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, numOutputBoxesCond);
|
||||
|
||||
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
|
||||
binder.getLoc(), scoreThreshold, nextTopScore);
|
||||
scoresCond = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), scoresCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond, scoresCond);
|
||||
suppressByIOUCond = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), suppressByIOUCond);
|
||||
|
||||
// if suppressByIOUCond is true do not select the box
|
||||
Value conditionBool = suppressByIOUCond;
|
||||
llvm::SmallVector<mlir::Type> ifTypes;
|
||||
Type elseType;
|
||||
|
||||
if (isa<Torch::NoneType>(selectedIndices.getType())) {
|
||||
elseType = rewriter.getType<Torch::ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1, 3}, resultType.getDtype());
|
||||
} else {
|
||||
auto selectedIndicesSizes =
|
||||
dyn_cast<Torch::ValueTensorType>(selectedIndices.getType())
|
||||
.getSizes();
|
||||
SmallVector<int64_t> selectedIndicesShape(selectedIndicesSizes);
|
||||
elseType = rewriter.getType<Torch::ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{selectedIndicesShape[0] + 1, 3},
|
||||
resultType.getDtype());
|
||||
}
|
||||
// ifTypes.push_back(selectedIndices.getType());
|
||||
ifTypes.push_back(elseType);
|
||||
auto primIf = rewriter.create<Torch::PrimIfOp>(
|
||||
binder.getLoc(), elseType, conditionBool);
|
||||
{
|
||||
Region &elseRegion = primIf.getElseRegion();
|
||||
rewriter.createBlock(&elseRegion, elseRegion.end());
|
||||
|
||||
if (isa<Torch::NoneType>(selectedIndices.getType())) {
|
||||
selectedIndices = selectedIndex;
|
||||
} else {
|
||||
Type listElemType =
|
||||
cast<Torch::BaseTensorType>(resultType)
|
||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});
|
||||
Value tensorList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(), listType,
|
||||
SmallVector<Value>{selectedIndices, selectedIndex});
|
||||
auto selectedIndicesSizes = dyn_cast<Torch::ValueTensorType>(
|
||||
selectedIndices.getType())
|
||||
.getSizes();
|
||||
SmallVector<int64_t> selectedIndicesShape(
|
||||
selectedIndicesSizes);
|
||||
selectedIndices = rewriter.create<Torch::AtenCatOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{
|
||||
selectedIndicesShape[0] + 1, 3},
|
||||
resultType.getDtype()),
|
||||
tensorList, /*dim=*/zero);
|
||||
}
|
||||
selectedBoxesInsideClass = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), selectedBoxesInsideClass, one);
|
||||
rewriter.create<Torch::PrimIfYieldOp>(
|
||||
binder.getLoc(), ValueRange{selectedIndices});
|
||||
}
|
||||
{
|
||||
Region &thenRegion = primIf.getThenRegion();
|
||||
rewriter.createBlock(&thenRegion, thenRegion.end());
|
||||
rewriter.create<Torch::PrimIfYieldOp>(
|
||||
binder.getLoc(), ValueRange{selectedIndices});
|
||||
}
|
||||
rewriter.setInsertionPointAfter(primIf);
|
||||
|
||||
// TODO: Add support for handling max_output_boxes_per_class arg.
|
||||
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
|
||||
// be lowered since the torchvision::nms op doesn't have support for
|
||||
// handling the max_output_boxes_per_class arg. Also, we have already
|
||||
// constrained the number of classes to be 1 above, so the number of
|
||||
// output boxes inferred from the result is num_output_boxes_per_class.
|
||||
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
|
||||
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
|
||||
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
binder.getLoc(), boxesCond,
|
||||
rewriter.getStringAttr(
|
||||
"unimplemented: number of output boxes per class should be "
|
||||
"<= max_output_boxes_per_class"));
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
|
||||
tensorList, dim);
|
||||
// set to scoreThreshold at maxScoreIndex in classScores so it
|
||||
// does not select the same scoreIndex before other boxes with
|
||||
// greater scores
|
||||
Value replace = scoreThreshold;
|
||||
dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::FloatType>()),
|
||||
SmallVector<Value>{replace});
|
||||
cstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
(int)torch_upstream::ScalarType::Float));
|
||||
Value replaceTensor = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(),
|
||||
ArrayRef<int64_t>{1},
|
||||
scoresType.getOptionalDtype()),
|
||||
dataList, /*dtype=*/cstDtype,
|
||||
/*layout=*/none, /*requires_grad=*/cstFalse);
|
||||
Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), nextTopScoreIndex, one);
|
||||
classScores = rewriter.create<Torch::AtenSliceScatterOp>(
|
||||
binder.getLoc(), classScores.getType(), classScores,
|
||||
replaceTensor,
|
||||
/*dim=*/zero, nextTopScoreIndex, insertEnd, /*step=*/one);
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(binder.op, selectedIndices);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue