Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / detection_output / detection_output_kernel_base.h
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #pragma once
18
19 #include "common_kernel_base.h"
20 #include "kernel_selector_params.h"
21
22 namespace kernel_selector
23 {
24     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
25     // detection_output_params
26     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
27     struct detection_output_params : public base_params
28     {
29         detection_output_params() : base_params(KernelType::DETECTION_OUTPUT), detectOutParams() {}
30
31         struct DedicatedParams
32         {
33             uint32_t num_images;
34             uint32_t num_classes;
35             int32_t keep_top_k;
36             int32_t top_k;
37             int32_t background_label_id;
38             int32_t code_type;
39             int32_t conf_size_x;
40             int32_t conf_size_y;
41             int32_t conf_padding_x;
42             int32_t conf_padding_y;
43             int32_t elements_per_thread;
44             int32_t input_width;
45             int32_t input_heigh;
46             int32_t prior_coordinates_offset;
47             int32_t prior_info_size;
48             bool prior_is_normalized;
49             bool share_location;
50             bool variance_encoded_in_target;
51             float nms_threshold;
52             float eta;
53             float confidence_threshold;
54         };
55
56         DedicatedParams detectOutParams;
57
58         virtual ParamsKey GetParamsKey() const
59         {
60             return base_params::GetParamsKey();
61         }
62     };
63
64     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
65     // detection_output_optional_params
66     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
67     struct detection_output_optional_params : optional_params
68     {
69         detection_output_optional_params() : optional_params(KernelType::DETECTION_OUTPUT) {}
70     };
71
72     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
73     // DetectionOutputKernelBase
74     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
75     class DetectionOutputKernelBase : public common_kernel_base
76     {
77     public:
78         using common_kernel_base :: common_kernel_base;
79         virtual ~DetectionOutputKernelBase() {}
80
81         using DispatchData = CommonDispatchData;
82     
83     protected:
84         JitConstants GetJitConstants(const detection_output_params& params) const;
85         virtual DispatchData SetDefault(const detection_output_params& params) const;
86     };
87 }