1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "pyramid_roi_align_kernel_base.h"
16 #include "kernel_selector_utils.h"
18 namespace kernel_selector {
20 JitConstants PyramidROIAlignKernelBase::GetJitConstants(const PyramidROIAlign_params& params)
22 JitConstants jit = MakeBaseParamsJitConstants(params);
// Default dispatch configuration for the PyramidROIAlign kernel.
//
// Visible behaviour: sets kd.fp16UnitUsed from the first input's data type,
// builds a global size vector {boxes.Y().v, 1, 1}, and asks
// GetOptimalLocalWorkGroupSizes() for a matching local size.
//
// NOTE(review): this listing is incomplete. Lines of the body are missing:
// the declaration of `kd` (presumably `DispatchData kd;`), whatever copies
// `global`/`local` into `kd` (presumably the work-size fields), the `return`,
// and the braces. As shown, the block cannot compile. Restore it from the
// upstream source rather than reconstructing it by hand.
26 PyramidROIAlignKernelBase::DispatchData PyramidROIAlignKernelBase::SetDefault(const PyramidROIAlign_params& params)
// First input, named `boxes`; its Y extent drives the global size below.
28 const auto& boxes = params.inputs.at(0);
// True when the first input's dtype is Datatype::F16.
// NOTE(review): params.inputs[0] is the same element already bound to
// `boxes`; reading boxes.GetDType() would be more consistent.
31 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// Global size: Y extent of `boxes` in dimension 0, and 1 in the other two.
// NOTE(review): could be a single const initialisation instead of
// declare-then-assign.
33 std::vector<size_t> global;
34 global = { boxes.Y().v, 1, 1 };
// Local size chosen to fit `global`. If the helper returns by value, binding
// to const& is safe because the temporary's lifetime is extended to `local`
// (confirm the helper's return type).
36 const auto& local = GetOptimalLocalWorkGroupSizes(global);
49 KernelsData PyramidROIAlignKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
51 assert(params.GetType() == KernelType::PYRAMID_ROI_ALIGN);
53 const auto& prim_params = static_cast<const PyramidROIAlign_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
54 auto run_info = SetDefault(prim_params);
55 KernelData k_data = KernelData::Default<PyramidROIAlign_params>(params);
56 auto cldnn_jit = GetJitConstants(prim_params);
57 auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
58 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
60 auto& kernel = k_data.kernels[0];
61 FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point, "", false, false, (uint32_t)prim_params.inputs.size());
63 k_data.estimatedTime = estimated_time;