mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for NonMaxSuppression op (#3501)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3523/head
parent
ea60d72489
commit
b6e4725259
|
@ -17145,3 +17145,28 @@ def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_TorchvisionNmsOp : Torch_Op<"torchvision.nms", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `torchvision::nms : (Tensor, Tensor, float) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$dets,
|
||||
AnyTorchTensorType:$scores,
|
||||
Torch_FloatType:$iou_threshold
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult TorchvisionNmsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void TorchvisionNmsOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -3050,4 +3050,145 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
/*layout=*/cstNone, /*requires_grad=*/cstFalse);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"NonMaxSuppression", 10,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
SmallVector<Value> operands;
|
||||
int64_t centerPointBox;
|
||||
if (binder.tensorOperandsList(operands) ||
|
||||
binder.s64IntegerAttr(centerPointBox, "center_point_box", 0) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
// TODO: Add support for non-zero center_point_box value.
|
||||
if (centerPointBox != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected center_point_box "
|
||||
"attribute value to be 0");
|
||||
|
||||
// TODO: Add support for optional arguments to be absent.
|
||||
if (operands.size() != 5)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected all 5 args to be present");
|
||||
|
||||
// Squeeze the boxes and scores tensor.
|
||||
// In Onnx, the shape of boxes is [BxNx4] while the
|
||||
// torchvision expects it to be of shape [Nx4]. Similarly, for
|
||||
// the scores tensor shape in Onnx is [BxCxN] while the
|
||||
// torchvision expects it to be of shape [N].
|
||||
Value boxes = operands[0], scores = operands[1];
|
||||
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, boxes);
|
||||
if (failed(squeezedBoxes))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze boxes tensor");
|
||||
|
||||
FailureOr<Value> squeezedScores = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, scores);
|
||||
if (failed(squeezedScores))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze scores tensor");
|
||||
squeezedScores = Torch::squeezeTensor(
|
||||
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
|
||||
if (failed(squeezedScores))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"failed to squeeze scores tensor");
|
||||
|
||||
boxes = squeezedBoxes.value();
|
||||
scores = squeezedScores.value();
|
||||
|
||||
// TODO: Add support for handling score_threshold arg.
|
||||
// If score_threshold > min(scores) then the op can't be lowered since
|
||||
// the torchvision::nms op doesn't have support for handling the
|
||||
// score_threshold arg.
|
||||
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
|
||||
Value minScores = rewriter.create<Torch::AtenMinOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(binder.op->getContext(), {},
|
||||
rewriter.getF32Type()),
|
||||
scores);
|
||||
minScores = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
|
||||
|
||||
Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
|
||||
binder.getLoc(), minScores, scoreThreshold);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
binder.getLoc(), scoresCond,
|
||||
rewriter.getStringAttr(
|
||||
"unimplemented: score_threshold should be <= min(scores)"));
|
||||
|
||||
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
|
||||
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
|
||||
binder.getLoc(), resultType, boxes, scores, iouThreshold);
|
||||
|
||||
// The result generated by torchvision.nms op is of shape [n], while the
|
||||
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
|
||||
// and make it of shape [n, 1] and then concatenate it with a zero
|
||||
// tensor of shape [n, 2] to make it of shape [n, 3].
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
FailureOr<Value> unsqueezedResult =
|
||||
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
|
||||
if (failed(unsqueezedResult))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "failed to unsqueeze result tensor");
|
||||
result = unsqueezedResult.value();
|
||||
|
||||
Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), result,
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
||||
SmallVector<Value> zerosShapeValues{numOutputBoxes};
|
||||
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
|
||||
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
zerosShapeValues);
|
||||
|
||||
std::optional<ArrayRef<int64_t>> resultShape =
|
||||
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
|
||||
if (!resultShape.has_value())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "expected result tensor to have shape");
|
||||
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
|
||||
auto zerosTy = Torch::ValueTensorType::get(
|
||||
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value zeros = rewriter.create<Torch::AtenZerosOp>(
|
||||
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
|
||||
cstNone);
|
||||
|
||||
Type listElemType =
|
||||
cast<Torch::BaseTensorType>(resultType)
|
||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});
|
||||
|
||||
// TODO: Add support for handling max_output_boxes_per_class arg.
|
||||
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
|
||||
// be lowered since the torchvision::nms op doesn't have support for
|
||||
// handling the max_output_boxes_per_class arg. Also, we have already
|
||||
// constrained the number of classes to be 1 above, so the number of
|
||||
// output boxes inferred from the result is num_output_boxes_per_class.
|
||||
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
|
||||
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
|
||||
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
binder.getLoc(), boxesCond,
|
||||
rewriter.getStringAttr(
|
||||
"unimplemented: number of output boxes per class should be "
|
||||
"<= max_output_boxes_per_class"));
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
|
||||
tensorList, dim);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -6285,6 +6285,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.torchvision.nms\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n"
|
||||
" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n"
|
||||
" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n"
|
||||
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" return %none : !torch.none\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.torchvision.nms\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float) -> !torch.int {\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" return %int3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
|
@ -10592,16 +10612,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n"
|
||||
" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n"
|
||||
" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n"
|
||||
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" return %none : !torch.none\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
|
|
|
@ -99,6 +99,12 @@ def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_sc
|
|||
def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]:
|
||||
return (input_rank_dtype[1], torch.int64)
|
||||
|
||||
def torchvision〇nms〡shape(dets: List[int], scores: List[int], iou_threshold: float) -> List[int]:
|
||||
return [hacky_get_unknown_dimension_size(), len(dets)]
|
||||
|
||||
def torchvision〇nms〡dtype(dets_rank_dtype: Tuple[int, int], scores_rank_dtype: Tuple[int, int], iou_threshold: float) -> int:
|
||||
return torch.int
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
|
||||
|
|
|
@ -1197,6 +1197,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit("torchvision::nms : (Tensor, Tensor, float) -> (Tensor)")
|
||||
|
||||
|
||||
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
||||
|
|
|
@ -1805,3 +1805,108 @@ func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtens
|
|||
}
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_nonmaxsuppression_identical_boxes(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,10,4],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,10],f32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
|
||||
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
|
||||
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4],f32>, %arg1: !torch.vtensor<[1,1,10],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.vtensor<[10,4],f32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
|
||||
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float
|
||||
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
|
||||
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1,3],si64>
|
||||
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64>
|
||||
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_32:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
|
||||
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
|
||||
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
|
||||
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
|
||||
return %0 : !torch.vtensor<[1,3],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_nonmaxsuppression_single_box(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
|
||||
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
|
||||
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""}
|
||||
func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
|
||||
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float
|
||||
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
|
||||
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1,3],si64>
|
||||
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64>
|
||||
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_32:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
|
||||
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
|
||||
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
|
||||
// CHECK: }
|
||||
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
|
||||
return %0 : !torch.vtensor<[1,3],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue