[ONNX] Add OnnxToTorch lowering for NonMaxSuppression op (#3501)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3523/head
Vivek Khandelwal 2024-07-26 21:01:27 +05:30 committed by GitHub
parent ea60d72489
commit b6e4725259
6 changed files with 298 additions and 10 deletions


@@ -17145,3 +17145,28 @@ def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
  }];
}

def Torch_TorchvisionNmsOp : Torch_Op<"torchvision.nms", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `torchvision::nms : (Tensor, Tensor, float) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$dets,
    AnyTorchTensorType:$scores,
    Torch_FloatType:$iou_threshold
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult TorchvisionNmsOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void TorchvisionNmsOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
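For reference, a minimal sketch of the torchvision Python API that the generated op wraps (the example boxes and scores are made up). torchvision.ops.nms takes boxes of shape [N, 4] in (x1, y1, x2, y2) format, scores of shape [N], and a float IoU threshold, and returns a 1-D int64 tensor of the indices of the kept boxes, sorted by decreasing score:

import torch
import torchvision

# Two heavily overlapping boxes; their IoU is roughly 0.68.
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])

keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0]) -- the lower-scoring overlapping box is suppressed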


@@ -3050,4 +3050,145 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            /*layout=*/cstNone, /*requires_grad=*/cstFalse);
        return success();
      });
  patterns.onOp(
      "NonMaxSuppression", 10,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        SmallVector<Value> operands;
        int64_t centerPointBox;
        if (binder.tensorOperandsList(operands) ||
            binder.s64IntegerAttr(centerPointBox, "center_point_box", 0) ||
            binder.tensorResultType(resultType))
          return failure();

        // TODO: Add support for non-zero center_point_box value.
        if (centerPointBox != 0)
          return rewriter.notifyMatchFailure(
              binder.op, "unimplemented: expected center_point_box "
                         "attribute value to be 0");

        // TODO: Add support for the optional arguments being absent.
        if (operands.size() != 5)
          return rewriter.notifyMatchFailure(
              binder.op, "unimplemented: expected all 5 args to be present");

        // Squeeze the boxes and scores tensors.
        // In ONNX, boxes is of shape [B, N, 4], while torchvision expects it
        // to be of shape [N, 4]. Similarly, the ONNX scores tensor is of
        // shape [B, C, N], while torchvision expects it to be of shape [N].
        Value boxes = operands[0], scores = operands[1];
        FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
            rewriter, binder.op, binder.getLoc(), 0, boxes);
        if (failed(squeezedBoxes))
          return rewriter.notifyMatchFailure(binder.op,
                                             "failed to squeeze boxes tensor");
        FailureOr<Value> squeezedScores = Torch::squeezeTensor(
            rewriter, binder.op, binder.getLoc(), 0, scores);
        if (failed(squeezedScores))
          return rewriter.notifyMatchFailure(binder.op,
                                             "failed to squeeze scores tensor");
        squeezedScores = Torch::squeezeTensor(
            rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
        if (failed(squeezedScores))
          return rewriter.notifyMatchFailure(binder.op,
                                             "failed to squeeze scores tensor");
        boxes = squeezedBoxes.value();
        scores = squeezedScores.value();

        // TODO: Add support for handling the score_threshold arg.
        // If score_threshold > min(scores), then the op can't be lowered,
        // since the torchvision::nms op has no support for the
        // score_threshold arg.
        Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
        Value minScores = rewriter.create<Torch::AtenMinOp>(
            binder.getLoc(),
            Torch::ValueTensorType::get(binder.op->getContext(), {},
                                        rewriter.getF32Type()),
            scores);
        minScores = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
        Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
            binder.getLoc(), minScores, scoreThreshold);
        rewriter.create<Torch::RuntimeAssertOp>(
            binder.getLoc(), scoresCond,
            rewriter.getStringAttr(
                "unimplemented: score_threshold should be <= min(scores)"));

        Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
            binder.getLoc(), resultType, boxes, scores, iouThreshold);

        // The result generated by the torchvision.nms op is of shape [n],
        // while ONNX expects it to be of shape [n, 3]. Hence, we unsqueeze
        // the tensor to shape [n, 1] and then concatenate it with a zero
        // tensor of shape [n, 2] to make it of shape [n, 3].
        Value dim = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        FailureOr<Value> unsqueezedResult =
            Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
        if (failed(unsqueezedResult))
          return rewriter.notifyMatchFailure(
              binder.op, "failed to unsqueeze result tensor");
        result = unsqueezedResult.value();

        Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
            binder.getLoc(), result,
            rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(0)));
        SmallVector<Value> zerosShapeValues{numOutputBoxes};
        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(2)));
        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            zerosShapeValues);
        std::optional<ArrayRef<int64_t>> resultShape =
            cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
        if (!resultShape.has_value())
          return rewriter.notifyMatchFailure(
              binder.op, "expected result tensor to have shape");
        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
        auto zerosTy = Torch::ValueTensorType::get(
            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value zeros = rewriter.create<Torch::AtenZerosOp>(
            binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
            cstNone);
        Type listElemType =
            cast<Torch::BaseTensorType>(resultType)
                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                      /*optionalDtype=*/nullptr);
        Type listType = Torch::ListType::get(listElemType);
        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
            binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});

        // TODO: Add support for handling the max_output_boxes_per_class arg.
        // If numOutputBoxes (N) > max_output_boxes_per_class, then the op
        // can't be lowered, since the torchvision::nms op has no support for
        // the max_output_boxes_per_class arg. Also, since we have already
        // constrained the number of classes to be 1 above, the number of
        // output boxes inferred from the result equals the number of output
        // boxes per class.
        Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
        Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
            binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
        rewriter.create<Torch::RuntimeAssertOp>(
            binder.getLoc(), boxesCond,
            rewriter.getStringAttr(
                "unimplemented: number of output boxes per class should be "
                "<= max_output_boxes_per_class"));

        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
                                                      tensorList, dim);
        return success();
      });
}
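Taken together, the pattern above is equivalent to the following Python sketch (the helper name and tensors are hypothetical; it mirrors the steps and runtime constraints of this lowering, not the full ONNX NonMaxSuppression spec):

import torch
import torchvision

def onnx_nms_via_torchvision(boxes, scores, max_output_boxes_per_class,
                             iou_threshold, score_threshold):
    # ONNX shapes: boxes [B, N, 4], scores [B, C, N]. The squeezes below only
    # succeed when B == 1 and C == 1, matching the squeeze asserts emitted by
    # the lowering.
    boxes = boxes.squeeze(0)               # [N, 4]
    scores = scores.squeeze(0).squeeze(0)  # [N]
    # torchvision::nms has no score_threshold arg, so only the case where the
    # threshold filters nothing is supported.
    assert float(scores.min()) >= float(score_threshold), \
        "unimplemented: score_threshold should be <= min(scores)"
    keep = torchvision.ops.nms(boxes, scores, float(iou_threshold))  # [n]
    keep = keep.unsqueeze(1)                                         # [n, 1]
    # Append two zero columns to reach the [n, 3] result shape, matching the
    # concatenation order used by the pattern above.
    zeros = torch.zeros(keep.size(0), 2, dtype=keep.dtype)
    result = torch.cat([keep, zeros], dim=1)                         # [n, 3]
    # torchvision::nms also has no max_output_boxes_per_class arg.
    assert result.size(0) <= int(max_output_boxes_per_class), \
        "unimplemented: number of output boxes per class should be " \
        "<= max_output_boxes_per_class"
    return result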


@@ -6285,6 +6285,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.torchvision.nms\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n"
" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n"
" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n"
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n"
" return %2 : !torch.int\n"
" }\n"
" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n"
" %none = torch.constant.none\n"
" return %none : !torch.none\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.torchvision.nms\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float) -> !torch.int {\n"
" %int3 = torch.constant.int 3\n"
" return %int3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
" %true = torch.constant.bool true\n"
@@ -10592,16 +10612,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n"
" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n"
" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n"
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n"
" return %2 : !torch.int\n"
" }\n"
" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n"
" %none = torch.constant.none\n"
" return %none : !torch.none\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"


@@ -99,6 +99,12 @@ def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_sc
def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]:
    return (input_rank_dtype[1], torch.int64)

def torchvision〇nms〡shape(dets: List[int], scores: List[int], iou_threshold: float) -> List[int]:
    return [hacky_get_unknown_dimension_size(), len(dets)]

def torchvision〇nms〡dtype(dets_rank_dtype: Tuple[int, int], scores_rank_dtype: Tuple[int, int], iou_threshold: float) -> int:
    return torch.int

@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
    Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
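For illustration, a self-contained sketch of what the new torchvision〇nms〡shape function above computes; UNKNOWN is a hypothetical stand-in for the value produced by hacky_get_unknown_dimension_size, which models a data-dependent dimension:

from typing import List

UNKNOWN = -1  # stand-in for the library's data-dependent size sentinel

def nms_result_shape(dets: List[int], scores: List[int],
                     iou_threshold: float) -> List[int]:
    # The leading dim (number of kept boxes) is data dependent; the trailing
    # dim is len(dets), i.e. the rank of the boxes shape list.
    return [UNKNOWN, len(dets)]

print(nms_result_shape([10, 4], [10], 0.5))  # [-1, 2]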


@@ -1197,6 +1197,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)"
    )
    emit("torchvision::nms : (Tensor, Tensor, float) -> (Tensor)")


def dump_registered_ops(outfile: TextIO, registry: Registry):


@@ -1805,3 +1805,108 @@ func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtens
}
return %0 : !torch.vtensor<[1],f32>
}
// -----
// CHECK-LABEL: func.func @test_nonmaxsuppression_identical_boxes(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,10,4],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,10],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4],f32>, %arg1: !torch.vtensor<[1,1,10],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.vtensor<[10,4],f32>
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1,3],si64>
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64>
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
// -----
// CHECK-LABEL: func.func @test_nonmaxsuppression_single_box(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1,3],si64>
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64>
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
// CHECK: }
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
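As a numerical sanity check of the single-box case above, the torchvision op backing the lowering can be exercised directly (a minimal sketch with made-up values; a lone box is always kept, so the result holds exactly one index):

import torch
import torchvision

boxes = torch.tensor([[[0.0, 0.0, 1.0, 1.0]]])  # ONNX layout [B, N, 4], B == N == 1
scores = torch.tensor([[[0.9]]])                # ONNX layout [B, C, N], B == C == 1

keep = torchvision.ops.nms(boxes.squeeze(0), scores.squeeze(0).squeeze(0), 0.5)
print(keep)  # tensor([0])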