1 // Copyright (c) 2016 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "pooling_kernel_base.h"
19 namespace kernel_selector {
20 bool PoolingKernelBase::Validate(const Params& p, const optional_params& o) const {
21 if (p.GetType() != KernelType::POOLING || o.GetType() != KernelType::POOLING) {
28 JitConstants PoolingKernelBase::GetJitConstants(const pooling_params& pp, PoolingKernelBase::DispatchData kd) const {
29 JitConstants mem_consts = MakeBaseParamsJitConstants(pp);
31 mem_consts.AddConstants({
32 MakeJitConstant("POOL", pp.poolSize),
33 MakeJitConstant("STRIDE", pp.poolStride),
34 MakeJitConstant("PADDING", pp.poolPad),
35 MakeJitConstant(toString(pp.poolType) + "_POOLING", 1),
36 MakeJitConstant(toString(pp.divMode) + "_KERNEL_DIVIDER", 1),
39 if (kd.needsBoundary) {
40 mem_consts.AddConstant(MakeJitConstant("CHECK_BOUNDRY", 1));
46 // Checks if we need boundary checking in kernel.
47 bool PoolingKernelBase::NeedsBoundaryCheck(const pooling_params& pp) const {
48 if (pp.poolPad.x != 0 || pp.poolPad.y != 0 || pp.poolPad.z != 0) {
52 const auto& input = pp.inputs[0];
54 if (input.X().v < pp.poolSize.x || input.Y().v < pp.poolSize.y || input.Z().v < pp.poolSize.z) {
58 if (pp.poolSize.x < 3 || pp.poolSize.y < 3) {
62 auto mod_x = (input.X().v - pp.poolSize.x) % pp.poolStride.x;
63 auto mod_y = (input.Y().v - pp.poolSize.y) % pp.poolStride.y;
64 auto mod_z = (input.Z().v - pp.poolSize.z) % pp.poolStride.z;
66 return mod_x || mod_y || mod_z;
69 PoolingKernelBase::DispatchData PoolingKernelBase::SetDefault(const pooling_params& params) const {
70 const auto& output = params.output;
74 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
76 if (output.GetLayout() == DataLayout::bfyx || output.GetLayout() == DataLayout::byxf ||
77 output.GetLayout() == DataLayout::bfzyx) {
78 // Determine global work sizes.
79 kd.gws2 = output.Batch().v * output.Feature().v; // B, F
80 kd.gws0 = Align(output.X().v, 32); // X
81 kd.gws1 = output.Y().v * output.Z().v; // Y, Z
83 // Find largest positive local work size that is divider for global work size.
88 // Determine global work sizes.
89 kd.gws0 = output.Batch().v * output.Feature().v; // B, F
90 kd.gws1 = output.X().v; // X
91 kd.gws2 = output.Y().v; // Y
93 kd.lws0 = std::min(std::max(kd.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
94 while (kd.gws0 % kd.lws0 != 0) {
101 kd.needsBoundary = NeedsBoundaryCheck(params);
106 KernelsData PoolingKernelBase::GetCommonKernelsData(const Params& params,
107 const optional_params& options,
108 float estimatedTime) const {
109 if (!Validate(params, options)) {
113 const pooling_params& orgParams = static_cast<const pooling_params&>(params);
115 DispatchData runInfo = SetDefault(orgParams);
117 KernelData kd = KernelData::Default<pooling_params>(params);
119 auto cldnn_jit = GetJitConstants(orgParams, runInfo);
120 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
121 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
123 auto& kernel = kd.kernels[0];
124 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
125 if (orgParams.poolType == PoolType::MAX_WITH_ARGMAX)
126 kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
128 kd.estimatedTime = estimatedTime;
132 } // namespace kernel_selector