Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / contract / contract_kernel_base.cpp
1 // Copyright (c) 2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
#include "contract_kernel_base.h"

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

#include "kernel_selector_utils.h"
19
20
21 namespace kernel_selector
22 {
23     JitConstants ContractKernelBase::GetJitConstants(const contract_params& params)
24     {
25         JitConstants jit = MakeBaseParamsJitConstants(params);
26
27         const size_t no_dim_flag = 6;
28         std::vector<size_t> output_dims(4, no_dim_flag);
29         int out_dim = 2;
30         for (int i = 3; i >= 0; --i)
31         {
32             if (std::find(params.reduction_axes.begin(), params.reduction_axes.end(), i) == params.reduction_axes.end())
33                 output_dims.at(i) = out_dim--;
34         }
35
36         if (output_dims[3] != no_dim_flag)
37             jit.AddConstants({
38                 MakeJitConstant("DIM_X", output_dims.at(3))
39             });
40         if (output_dims[2] != no_dim_flag)
41             jit.AddConstants({
42                 MakeJitConstant("DIM_Y", output_dims.at(2))
43             });
44         if (output_dims[1] != no_dim_flag)
45             jit.AddConstants({
46                 MakeJitConstant("DIM_F", output_dims.at(1))
47             });
48         if (output_dims[0] != no_dim_flag)
49             jit.AddConstants({
50                 MakeJitConstant("DIM_B", output_dims.at(0))
51             });
52
53         jit.AddConstants({
54             MakeJitConstant("REDUCE_X", output_dims.at(3) == no_dim_flag),
55             MakeJitConstant("REDUCE_Y", output_dims.at(2) == no_dim_flag),
56             MakeJitConstant("REDUCE_F", output_dims.at(1) == no_dim_flag),
57             MakeJitConstant("REDUCE_B", output_dims.at(0) == no_dim_flag)
58         });
59
60         switch (params.mode)
61         {
62         case ContractMode::SUM:
63             jit.AddConstants({
64                 MakeJitConstant("REDUCE_SEED", "0"),
65                 MakeJitConstant("REDUCE_OPERATION(a, b)", "a + b")
66             });
67             break;
68         case ContractMode::PRODUCT:
69             jit.AddConstants({
70                 MakeJitConstant("REDUCE_SEED", "1"),
71                 MakeJitConstant("REDUCE_OPERATION(a, b)", "a * b")
72             });
73             break;
74         case ContractMode::ALL:
75             jit.AddConstants({
76                 MakeJitConstant("REDUCE_SEED", "1"),
77                 MakeJitConstant("REDUCE_OPERATION(a, b)", "a && b")
78             });
79             break;
80         case ContractMode::ANY:
81             jit.AddConstants({
82                 MakeJitConstant("REDUCE_SEED", "0"),
83                 MakeJitConstant("REDUCE_OPERATION(a, b)", "a || b")
84             });
85             break;
86         case ContractMode::MAX:
87             jit.AddConstants({
88                 MakeJitConstant("REDUCE_SEED", "UNIT_VAL_MIN"),
89                 MakeJitConstant("REDUCE_OPERATION(a, b)", "UNIT_MAX_FUNC(a,b)")
90             });
91             break;
92         }
93
94         return jit;
95     }
96
97     ContractKernelBase::DispatchData ContractKernelBase::SetDefault(const contract_params& params)
98     {
99         const auto& output = params.output;
100
101         DispatchData kd;
102
103         kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
104
105         std::vector<size_t> global{ output.Feature().v, output.Y().v, output.X().v };
106         const auto& local = GetOptimalLocalWorkGroupSizes(global);
107
108         kd.gws0 = global[0];
109         kd.gws1 = global[1];
110         kd.gws2 = global[2];
111
112         kd.lws0 = local[0];
113         kd.lws1 = local[1];
114         kd.lws2 = local[2];
115
116         return kd;
117     }
118
119     KernelsData ContractKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
120     {
121         assert(params.GetType() == KernelType::CONTRACT);
122
123         const auto& prim_params = static_cast<const contract_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
124
125         auto run_info = SetDefault(prim_params);
126         KernelData k_data = KernelData::Default<contract_params>(params);
127
128         auto cldnn_jit = GetJitConstants(prim_params);
129         auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
130         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
131
132         auto& kernel = k_data.kernels[0];
133         FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point);
134         k_data.estimatedTime = estimated_time;
135
136         return{ k_data };
137     }
138 }