2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "detection_output_kernel_sort.h"
18 #include "kernel_selector_utils.h"
20 #define DETECTION_OUTPUT_ROW_SIZE 7 // Each detection consists of [image_id, label, confidence, xmin, ymin, xmax, ymax].
22 namespace kernel_selector
25 ParamsKey DetectionOutputKernel_sort::GetSupportedKey() const
28 k.EnableInputDataType(Datatype::F16);
29 k.EnableInputDataType(Datatype::F32);
30 k.EnableOutputDataType(Datatype::F16);
31 k.EnableOutputDataType(Datatype::F32);
32 k.EnableInputLayout(DataLayout::bfyx);
33 k.EnableOutputLayout(DataLayout::bfyx);
34 k.EnableTensorOffset();
35 k.EnableTensorPitches();
40 CommonDispatchData DetectionOutputKernel_sort::SetDefault(const detection_output_params& params) const
42 CommonDispatchData runInfo = DetectionOutputKernelBase::SetDefault(params);
44 unsigned class_num = params.detectOutParams.num_classes;
45 if (params.detectOutParams.share_location && params.detectOutParams.background_label_id == 0)
49 const size_t bboxesNum = class_num * params.detectOutParams.num_images;
50 // Work group size is set to number of bounding boxes per image
51 size_t work_group_size = class_num;
53 if (work_group_size > 256)
55 work_group_size = (work_group_size + work_group_size % 2) / (work_group_size / 256 + 1);
58 runInfo.gws0 = Align(bboxesNum, work_group_size);
62 runInfo.lws0 = work_group_size;
69 KernelsData DetectionOutputKernel_sort::GetKernelsData(const Params& params, const optional_params& options) const
71 assert(params.GetType() == KernelType::DETECTION_OUTPUT &&
72 options.GetType() == KernelType::DETECTION_OUTPUT);
74 KernelData kd = KernelData::Default<detection_output_params>(params);
75 const detection_output_params& detectOutParams = static_cast<const detection_output_params&>(params);
76 DispatchData runInfo = SetDefault(detectOutParams);
78 auto cldnnJit = GetJitConstants(detectOutParams);
79 auto entryPoint = GetEntryPoint(kernelName, detectOutParams.layerID, options);
80 auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
82 auto& kernel = kd.kernels[0];
83 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entryPoint);
85 kd.estimatedTime = FORCE_PRIORITY_8;