Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / detection_output / detection_output_kernel_ref.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "detection_output_kernel_ref.h"
18 #include "kernel_selector_utils.h"
19
20 #define PRIOR_BOX_SIZE 4 // Each prior-box consists of [xmin, ymin, xmax, ymax].
21
22 namespace kernel_selector
23 {
24
25     ParamsKey DetectionOutputKernel::GetSupportedKey() const
26     {
27         ParamsKey k;
28         k.EnableInputDataType(Datatype::F16);
29         k.EnableInputDataType(Datatype::F32);
30         k.EnableOutputDataType(Datatype::F16);
31         k.EnableOutputDataType(Datatype::F32);
32         k.EnableInputLayout(DataLayout::bfyx);
33         k.EnableOutputLayout(DataLayout::bfyx);
34         k.EnableTensorOffset();
35         k.EnableTensorPitches();
36         k.EnableBatching();
37         return k;
38     }
39
40     CommonDispatchData DetectionOutputKernel::SetDefault(const detection_output_params& params) const
41     {
42         CommonDispatchData runInfo = DetectionOutputKernelBase::SetDefault(params);
43
44         // Number of all work items is set to total number of bounding boxes -
45         // one bounding box is procerssed by one work item
46         size_t num_classes = (params.detectOutParams.share_location)? 1 : params.detectOutParams.num_classes;
47
48         // Size of input0 (input location), if shared loaction it is equal to size of one class, 
49         // else it has size of all items for all classes
50         size_t bboxesNum = params.inputs[0].LogicalSize() / PRIOR_BOX_SIZE / num_classes;
51         // Work group size is set to number of bounding boxes per image for sorting purpose
52         // (access to one table with sorted values)
53         size_t work_group_size = bboxesNum / params.inputs[0].Batch().v;
54
55         if (work_group_size > 256)
56         {
57             work_group_size = work_group_size / ((work_group_size / 256) + 1) + 1;
58         }
59
60         bboxesNum = work_group_size * params.inputs[0].Batch().v;
61
62         runInfo.gws0 = Align(bboxesNum, work_group_size);
63         runInfo.gws1 = 1;
64         runInfo.gws2 = 1;
65
66         runInfo.lws0 = work_group_size;
67         runInfo.lws1 = 1;
68         runInfo.lws2 = 1;
69
70         return runInfo;
71     }
72
73     KernelsData DetectionOutputKernel::GetKernelsData(const Params& params, const optional_params& options) const
74     {
75         assert(params.GetType() == KernelType::DETECTION_OUTPUT &&
76                options.GetType() == KernelType::DETECTION_OUTPUT);
77
78         KernelData kd = KernelData::Default<detection_output_params>(params);
79         const detection_output_params& detectOutParams = static_cast<const detection_output_params&>(params);
80         DispatchData runInfo = SetDefault(detectOutParams);
81
82         auto cldnnJit = GetJitConstants(detectOutParams);
83         auto entryPoint = GetEntryPoint(kernelName, detectOutParams.layerID, options);
84         auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
85
86         auto& kernel = kd.kernels[0];
87         FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entryPoint);
88         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
89         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 2 });
90
91         kd.estimatedTime = FORCE_PRIORITY_8;
92
93         return{ kd };
94     }
95 }