2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "detection_output_kernel_ref.h"
18 #include "kernel_selector_utils.h"
20 #define PRIOR_BOX_SIZE 4 // Each prior-box consists of [xmin, ymin, xmax, ymax].
22 namespace kernel_selector
// Describes the parameter combinations this reference detection-output kernel
// declares support for, by switching on capability flags of the key `k`.
25 ParamsKey DetectionOutputKernel::GetSupportedKey() const
// Both half (F16) and single (F32) precision are enabled for inputs and outputs.
28 k.EnableInputDataType(Datatype::F16);
29 k.EnableInputDataType(Datatype::F32);
30 k.EnableOutputDataType(Datatype::F16);
31 k.EnableOutputDataType(Datatype::F32);
// Only the bfyx layout is enabled, for inputs as well as outputs.
32 k.EnableInputLayout(DataLayout::bfyx);
33 k.EnableOutputLayout(DataLayout::bfyx);
// Tensor offsets and pitches are enabled as well
// (presumably to allow padded / non-dense buffers - confirm against ParamsKey).
34 k.EnableTensorOffset();
35 k.EnableTensorPitches();
// Computes the dispatch sizes (global / local work size in dimension 0) for the
// detection-output kernel, starting from the base-class defaults.
//
// params - detection output parameters; inputs[0] is used as the location input
//          and detectOutParams supplies share_location / num_classes.
// Returns the dispatch data with gws0 and lws0 overridden as computed below.
40 CommonDispatchData DetectionOutputKernel::SetDefault(const detection_output_params& params) const
42 CommonDispatchData runInfo = DetectionOutputKernelBase::SetDefault(params);
44 // Number of all work items is set to total number of bounding boxes -
45 // one bounding box is processed by one work item
// With a shared location the class dimension collapses to a single entry.
46 size_t num_classes = (params.detectOutParams.share_location)? 1 : params.detectOutParams.num_classes;
48 // Size of input0 (input location), if shared location it is equal to size of one class,
49 // else it has size of all items for all classes
50 size_t bboxesNum = params.inputs[0].LogicalSize() / PRIOR_BOX_SIZE / num_classes;
51 // Work group size is set to number of bounding boxes per image for sorting purpose
52 // (access to one table with sorted values)
// NOTE(review): divides by the batch size of inputs[0]; a zero batch would be a
// division by zero - presumably excluded by earlier validation, confirm.
53 size_t work_group_size = bboxesNum / params.inputs[0].Batch().v;
// Cap the work group at 256 items: the per-image box count is split into
// (work_group_size / 256 + 1) parts and the quotient is bumped by one, which
// always yields a value <= 256 (e.g. 1000 -> 1000 / 4 + 1 = 251).
55 if (work_group_size > 256)
57 work_group_size = work_group_size / ((work_group_size / 256) + 1) + 1;
// Recompute the total work-item count from the (possibly reduced) group size.
60 bboxesNum = work_group_size * params.inputs[0].Batch().v;
// NOTE(review): Align presumably rounds bboxesNum up to a multiple of
// work_group_size; a work_group_size of 0 (fewer boxes than batches) would be
// passed through here unchanged - confirm this cannot happen.
62 runInfo.gws0 = Align(bboxesNum, work_group_size);
66 runInfo.lws0 = work_group_size;
// Builds the kernel data for a detection-output primitive: dispatch sizes,
// JIT constants / generated source, kernel arguments and scheduling priority.
//
// params  - must be of type KernelType::DETECTION_OUTPUT (asserted); it is
//           down-cast to detection_output_params.
// options - must also be of type KernelType::DETECTION_OUTPUT (asserted); it is
//           forwarded to GetEntryPoint.
73 KernelsData DetectionOutputKernel::GetKernelsData(const Params& params, const optional_params& options) const
// Debug-only type check; the static_cast below relies on it.
75 assert(params.GetType() == KernelType::DETECTION_OUTPUT &&
76 options.GetType() == KernelType::DETECTION_OUTPUT);
78 KernelData kd = KernelData::Default<detection_output_params>(params);
79 const detection_output_params& detectOutParams = static_cast<const detection_output_params&>(params);
// Work sizes come from this class's SetDefault.
80 DispatchData runInfo = SetDefault(detectOutParams);
// Generate the JIT constants, the entry point name and the final JIT for the kernel.
82 auto cldnnJit = GetJitConstants(detectOutParams);
83 auto entryPoint = GetEntryPoint(kernelName, detectOutParams.layerID, options);
84 auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
// Only the first kernel slot of kd is populated here.
86 auto& kernel = kd.kernels[0];
87 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entryPoint);
// Append inputs 1 and 2 as extra kernel arguments
// (presumably input 0 and the output are already registered by FillCLKernelData - confirm).
88 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
89 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 2 });
// Priority hint stored as the estimated time for this implementation.
91 kd.estimatedTime = FORCE_PRIORITY_8;