2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "detection_output_kernel_base.h"
19 namespace kernel_selector
21 JitConstants DetectionOutputKernelBase::GetJitConstants(const detection_output_params & params) const
23 JitConstants jit = MakeBaseParamsJitConstants(params);
25 const auto& detectOutParams = params.detectOutParams;
28 MakeJitConstant("NUM_IMAGES", detectOutParams.num_images),
29 MakeJitConstant("NUM_CLASSES", detectOutParams.num_classes),
30 MakeJitConstant("KEEP_TOP_K", detectOutParams.keep_top_k),
31 MakeJitConstant("TOP_K", detectOutParams.top_k),
32 MakeJitConstant("BACKGROUND_LABEL_ID", detectOutParams.background_label_id),
33 MakeJitConstant("CODE_TYPE", detectOutParams.code_type),
34 MakeJitConstant("CONF_SIZE_X", detectOutParams.conf_size_x),
35 MakeJitConstant("CONF_SIZE_Y", detectOutParams.conf_size_y),
36 MakeJitConstant("CONF_PADDING_X", detectOutParams.conf_padding_x),
37 MakeJitConstant("CONF_PADDING_Y", detectOutParams.conf_padding_y),
38 MakeJitConstant("SHARE_LOCATION", detectOutParams.share_location),
39 MakeJitConstant("VARIANCE_ENCODED_IN_TARGET", detectOutParams.variance_encoded_in_target),
40 MakeJitConstant("NMS_THRESHOLD", detectOutParams.nms_threshold),
41 MakeJitConstant("ETA", detectOutParams.eta),
42 MakeJitConstant("CONFIDENCE_THRESHOLD", detectOutParams.confidence_threshold),
43 MakeJitConstant("IMAGE_WIDTH", detectOutParams.input_width),
44 MakeJitConstant("IMAGE_HEIGH", detectOutParams.input_heigh),
45 MakeJitConstant("ELEMENTS_PER_THREAD", detectOutParams.elements_per_thread),
46 MakeJitConstant("PRIOR_COORD_OFFSET", detectOutParams.prior_coordinates_offset),
47 MakeJitConstant("PRIOR_INFO_SIZE", detectOutParams.prior_info_size),
48 MakeJitConstant("PRIOR_IS_NORMALIZED", detectOutParams.prior_is_normalized),
54 DetectionOutputKernelBase::DispatchData DetectionOutputKernelBase::SetDefault(const detection_output_params& params) const
58 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;