560e78ae557f071ed870c3755af83c7f669d1325
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / pooling / pooling_kernel_base.cpp
1 // Copyright (c) 2016 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "pooling_kernel_base.h"
17 #include <algorithm>
18
19 namespace kernel_selector {
20 bool PoolingKernelBase::Validate(const Params& p, const optional_params& o) const {
21     if (p.GetType() != KernelType::POOLING || o.GetType() != KernelType::POOLING) {
22         return false;
23     }
24
25     return true;
26 }
27
28 JitConstants PoolingKernelBase::GetJitConstants(const pooling_params& pp, PoolingKernelBase::DispatchData kd) const {
29     JitConstants mem_consts = MakeBaseParamsJitConstants(pp);
30
31     mem_consts.AddConstants({
32         MakeJitConstant("POOL", pp.poolSize),
33         MakeJitConstant("STRIDE", pp.poolStride),
34         MakeJitConstant("PADDING", pp.poolPad),
35         MakeJitConstant(toString(pp.poolType) + "_POOLING", 1),
36         MakeJitConstant(toString(pp.divMode) + "_KERNEL_DIVIDER", 1),
37     });
38
39     if (kd.needsBoundary) {
40         mem_consts.AddConstant(MakeJitConstant("CHECK_BOUNDRY", 1));
41     }
42
43     return mem_consts;
44 }
45
46 // Checks if we need boundary checking in kernel.
47 bool PoolingKernelBase::NeedsBoundaryCheck(const pooling_params& pp) const {
48     if (pp.poolPad.x != 0 || pp.poolPad.y != 0 || pp.poolPad.z != 0) {
49         return true;
50     }
51
52     const auto& input = pp.inputs[0];
53
54     if (input.X().v < pp.poolSize.x || input.Y().v < pp.poolSize.y || input.Z().v < pp.poolSize.z) {
55         return true;
56     }
57
58     if (pp.poolSize.x < 3 || pp.poolSize.y < 3) {
59         return true;
60     }
61
62     auto mod_x = (input.X().v - pp.poolSize.x) % pp.poolStride.x;
63     auto mod_y = (input.Y().v - pp.poolSize.y) % pp.poolStride.y;
64     auto mod_z = (input.Z().v - pp.poolSize.z) % pp.poolStride.z;
65
66     return mod_x || mod_y || mod_z;
67 }
68
69 PoolingKernelBase::DispatchData PoolingKernelBase::SetDefault(const pooling_params& params) const {
70     const auto& output = params.output;
71
72     DispatchData kd;
73
74     kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
75
76     if (output.GetLayout() == DataLayout::bfyx || output.GetLayout() == DataLayout::byxf ||
77         output.GetLayout() == DataLayout::bfzyx) {
78         // Determine global work sizes.
79         kd.gws2 = output.Batch().v * output.Feature().v;  // B, F
80         kd.gws0 = Align(output.X().v, 32);                // X
81         kd.gws1 = output.Y().v * output.Z().v;            // Y, Z
82
83         // Find largest positive local work size that is divider for global work size.
84         kd.lws0 = 32;
85         kd.lws1 = 1;
86         kd.lws2 = 1;
87     } else {
88         // Determine global work sizes.
89         kd.gws0 = output.Batch().v * output.Feature().v;  // B, F
90         kd.gws1 = output.X().v;                           // X
91         kd.gws2 = output.Y().v;                           // Y
92
93         kd.lws0 = std::min(std::max(kd.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
94         while (kd.gws0 % kd.lws0 != 0) {
95             --kd.lws0;
96         }
97         kd.lws1 = 1;
98         kd.lws2 = 1;
99     }
100
101     kd.needsBoundary = NeedsBoundaryCheck(params);
102
103     return kd;
104 }
105
106 KernelsData PoolingKernelBase::GetCommonKernelsData(const Params& params,
107                                                     const optional_params& options,
108                                                     float estimatedTime) const {
109     if (!Validate(params, options)) {
110         return {};
111     }
112
113     const pooling_params& orgParams = static_cast<const pooling_params&>(params);
114
115     DispatchData runInfo = SetDefault(orgParams);
116
117     KernelData kd = KernelData::Default<pooling_params>(params);
118
119     auto cldnn_jit = GetJitConstants(orgParams, runInfo);
120     auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
121     auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
122
123     auto& kernel = kd.kernels[0];
124     FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
125     if (orgParams.poolType == PoolType::MAX_WITH_ARGMAX)
126         kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
127
128     kd.estimatedTime = estimatedTime;
129
130     return {kd};
131 }
132 }  // namespace kernel_selector