Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / detection_output.hpp
index 8d3d75c..87ea568 100644 (file)
@@ -18,6 +18,7 @@
 #pragma once
 #include <limits>
 #include "../C/detection_output.h"
+#include "../C/detection_output_sort.h"
 #include "primitive.hpp"
 
 namespace cldnn
@@ -39,7 +40,7 @@ enum class prior_box_code_type : int32_t
 
 /// @brief Generates a list of detections based on location and confidence predictions by doing non maximum suppression.
 /// @details Each row is a 7 dimension vector, which stores: [image_id, label, confidence, xmin, ymin, xmax, ymax].
-/// If number of detections per image is lower than keep_top_k, will write dummy results at the end with image_id=-1. 
+/// If number of detections per image is lower than keep_top_k, will write dummy results at the end with image_id=-1.
 struct detection_output : public primitive_base<detection_output, CLDNN_PRIMITIVE_DESC(detection_output)>
 {
     CLDNN_DECLARE_PRIMITIVE(detection_output)
@@ -80,7 +81,8 @@ struct detection_output : public primitive_base<detection_output, CLDNN_PRIMITIV
         const int32_t input_width = -1,
         const int32_t input_height = -1,
         const bool decrease_label_id = false,
-        const bool clip = false,
+        const bool clip_before_nms = false,
+        const bool clip_after_nms = false,
         const padding& output_padding = padding()
         )
         : primitive_base(id, { input_location, input_confidence, input_prior_box }, output_padding)
@@ -100,7 +102,8 @@ struct detection_output : public primitive_base<detection_output, CLDNN_PRIMITIV
         , input_width(input_width)
         , input_height(input_height)
         , decrease_label_id(decrease_label_id)
-        , clip(clip)
+        , clip_before_nms(clip_before_nms)
+        , clip_after_nms(clip_after_nms)
     {
         if (decrease_label_id && background_label_id != 0)
             throw std::invalid_argument("Cannot use decrease_label_id and background_label_id parameter simultaneously.");
@@ -125,7 +128,8 @@ struct detection_output : public primitive_base<detection_output, CLDNN_PRIMITIV
         , input_width(dto->input_width)
         , input_height(dto->input_height)
         , decrease_label_id(dto->decrease_label_id != 0)
-        , clip(dto->clip != 0)
+        , clip_before_nms(dto->clip_before_nms != 0)
+        , clip_after_nms(dto->clip_after_nms != 0)
     {
         if (decrease_label_id && background_label_id != 0)
             throw std::invalid_argument("Cannot use decrease_label_id and background_label_id parameter simultaneously.");
@@ -163,8 +167,10 @@ struct detection_output : public primitive_base<detection_output, CLDNN_PRIMITIV
     const int32_t input_height;
     /// @brief Decrease label id to skip background label equal to 0. Can't be used simultaneously with background_label_id.
     const bool decrease_label_id;
-    /// @brief Clip decoded boxes
-    const bool clip;
+    /// @brief Clip decoded boxes right after decoding
+    const bool clip_before_nms;
+    /// @brief Clip decoded boxes after nms step
+    const bool clip_after_nms;
 
 protected:
     void update_dto(dto& dto) const override
@@ -185,7 +191,81 @@ protected:
         dto.input_width = input_width;
         dto.input_height = input_height;
         dto.decrease_label_id = decrease_label_id;
-        dto.clip = clip;
+        dto.clip_before_nms = clip_before_nms;
+        dto.clip_after_nms = clip_after_nms;
+    }
+};
+
+/// @brief Generates a list of detections based on location and confidence predictions by doing non maximum suppression.
+/// @details Each row is a 7 dimension vector, which stores: [image_id, label, confidence, xmin, ymin, xmax, ymax].
+/// If number of detections per image is lower than keep_top_k, will write dummy results at the end with image_id=-1.
+struct detection_output_sort : public primitive_base<detection_output_sort, CLDNN_PRIMITIVE_DESC(detection_output_sort)>
+{
+    CLDNN_DECLARE_PRIMITIVE(detection_output_sort)
+
+    /// @brief Constructs detection output primitive.
+    /// @param id This primitive id.
+    /// @param input_bboxes Input bounding boxes primitive id.
+    /// @param num_images Number of images to be predicted.
+    /// @param num_classes Number of classes to be predicted.
+    /// @param keep_top_k Number of total bounding boxes to be kept per image after NMS step.
+    /// @param share_location If true bounding box are shared among different classes.
+    /// @param top_k Maximum number of results to be kept in NMS.
+    /// @param output_padding Output padding.
+    detection_output_sort(
+        const primitive_id& id,
+        const primitive_id& input_bboxes,
+        const uint32_t num_images,
+        const uint32_t num_classes,
+        const uint32_t keep_top_k,
+        const bool share_location = true,
+        const int top_k = -1,
+        const int background_label_id = -1,
+        const padding& output_padding = padding()
+    )
+    : primitive_base(id, { input_bboxes }, output_padding)
+    , num_images(num_images)
+    , num_classes(num_classes)
+    , keep_top_k(keep_top_k)
+    , share_location(share_location)
+    , top_k(top_k)
+    , background_label_id(background_label_id)
+    {}
+
+    /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{detection_output}
+    detection_output_sort(const dto* dto)
+        : primitive_base(dto)
+        , num_images(dto->num_images)
+        , num_classes(dto->num_classes)
+        , keep_top_k(dto->keep_top_k)
+        , share_location(dto->share_location != 0)
+        , top_k(dto->top_k)
+        , background_label_id(dto->background_label_id)
+    {}
+
+    /// @brief Number of classes to be predicted.
+    const uint32_t num_images;
+    /// @brief Number of classes to be predicted.
+    const uint32_t num_classes;
+    /// @brief Number of total bounding boxes to be kept per image after NMS step.
+    const int keep_top_k;
+    /// @brief If true, bounding box are shared among different classes.
+    const bool share_location;
+    /// @brief Maximum number of results to be kept in NMS.
+    const int top_k;
+    /// @brief Background label id (-1 if there is no background class).
+    const int background_label_id;
+
+
+protected:
+    void update_dto(dto& dto) const override
+    {
+        dto.num_classes = num_classes;
+        dto.num_images = num_images;
+        dto.keep_top_k = keep_top_k;
+        dto.share_location = share_location;
+        dto.top_k = top_k;
+        dto.background_label_id = background_label_id;
     }
 };
 /// @}