Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_base.cpp
1 /*
2 // Copyright (c) 2016-2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "eltwise_kernel_base.h"
18 #include "kernel_selector_utils.h"
19
20 namespace kernel_selector
21 {
22     static uint32_t GetNumberOfInputs(EltwiseMode m)
23     {
24         switch (m)
25         {
26         case EltwiseMode::ADD:
27         case EltwiseMode::SUB:
28         case EltwiseMode::MUL:
29         case EltwiseMode::DIV:
30         case EltwiseMode::MIN:
31         case EltwiseMode::MAX:
32         case EltwiseMode::POW:
33         case EltwiseMode::MODULU:
34         case EltwiseMode::EQ:
35         case EltwiseMode::NE:
36         case EltwiseMode::LT:
37         case EltwiseMode::LE:
38         case EltwiseMode::GT:
39         case EltwiseMode::GE:
40         case EltwiseMode::LOGIC_AND:
41         case EltwiseMode::LOGIC_OR:
42         case EltwiseMode::LOGIC_XOR:
43         case EltwiseMode::SQUARED_DIFF:
44             return 2;
45         case EltwiseMode::SQRT:
46         case EltwiseMode::RSQRT:
47         case EltwiseMode::ASSIGN:
48             return 1;
49         default:
50             return 0;
51         }
52     }
53
54     ParamsKey eltwise_params::GetParamsKey() const
55     {
56         ParamsKey k = base_params::GetParamsKey();
57         if (int8_quantization)
58         {
59             k.EnableInt8Quantization();
60         }
61
62         if (output_calibration)
63         {
64             k.EnableOutputCalibration();
65         }
66
67         if (!stride.empty())
68         {
69             k.EnableEltwiseStride();
70         }
71
72         if (broadcast)
73         {
74             k.EnableEltwiseBroadcast();
75         }
76
77         return k;
78     }
79
80     bool EltwiseKernelBase::Validate(const Params& p, const optional_params& o) const
81     {
82         if (p.GetType() != KernelType::ELTWISE ||
83             o.GetType() != KernelType::ELTWISE)
84         {
85             return false;
86         }
87
88         const eltwise_params& params = static_cast<const eltwise_params&>(p);
89
90         if (params.inputs.size() == 0)
91         {
92             return false;
93         }
94
95         auto& operations = params.operations;
96
97         if (operations.size() == 0)
98         {
99             return false;
100         }
101
102         for (size_t op_num = 0; op_num < operations.size(); op_num++)
103         {
104             const auto& ew = operations[op_num];
105
106             if (ew.inputs.size() != GetNumberOfInputs(ew.mode))
107             {
108                 return false;
109             }
110
111             for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
112             {
113                 const auto& input = ew.inputs[input_idx];
114                 if (input.mode == EltwiseInputMode::INPUT_BUFFER &&
115                     input.index >= params.inputs.size())
116                 {
117                     return false;
118                 }
119             }
120         }
121
122         return true;
123     }
124
125     JitConstants EltwiseKernelBase::GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const
126     {
127         JitConstants jit = MakeBaseParamsJitConstants(params);
128
129         jit.AddConstants({
130             MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
131             MakeJitConstant("QUANTIZATION_TERM",    params.int8_quantization),
132             MakeJitConstant("ELTWISE_BROADCAST",    params.broadcast),
133         });
134
135         if (params.int8_quantization)
136         {
137             if (params.output_calibration)
138             {
139                 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
140                 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
141
142             }
143             else
144                 jit.AddConstants({ MakeJitConstant("O_QF",       params.output_quantization_factor) });
145         }
146
147         std::string inputs_decls, vload_decls;
148         auto& updateInputs = params.updateInputIds;
149
150         for (size_t i = 0; i < params.inputs.size(); i++)
151         {
152             //const should be added only to inputs which will not be updated
153             std::string const_str = "const";
154             for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
155             {
156                 if (updateInputs[update_input_idx].inputId == i)
157                 {
158                     const_str = "";
159                     break;
160                 }
161             }
162
163             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
164             if (!params.stride.empty())
165             {
166                 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
167                 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
168             }
169             if (useVload8)
170             {
171                 vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
172                 if (params.inputs[i].PhysicalSize() == 1) //Scalar case
173                     vload_decls += " = (" + toCLType(params.inputs[i].GetDType()) + "8)(input" + std::to_string(i) + "[0]";
174                 else //Buffer case
175                     vload_decls += " = vload8(global_id, input" + std::to_string(i);
176                 vload_decls += ");";
177             }
178         }
179
180         jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
181         jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
182
183         if (useVload8)
184             jit.AddConstant(MakeJitConstant("VLOAD_DECLS", vload_decls));
185
186         std::string do_eltwise;
187
188         auto& operations   = params.operations;
189         auto& coefficients = params.coefficients;
190
191         for (size_t op_num = 0; op_num < operations.size(); op_num++)
192         {
193             const std::string op_num_str = std::to_string(op_num);
194             const auto& ew = operations[op_num];
195
196             for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
197             {
198                 const auto& input = ew.inputs[input_idx];
199                 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
200                 switch (input.mode)
201                 {
202                 case EltwiseInputMode::SCALAR:
203                     jit.AddConstant(MakeJitConstant(name, input.scalar));
204                     break;
205                 case EltwiseInputMode::INPUT_BUFFER:
206                     if(useVload8)
207                         jit.AddConstant(MakeJitConstant(name, "in" + std::to_string(input.index)));
208                     else
209                         jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[GET_INDEX(INPUT, " + std::to_string(input.index) + ")]"));
210                     break;
211                 case EltwiseInputMode::OUTPUT_BUFFER:
212                     jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
213                     break;
214                 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
215                     jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
216                     break;
217                 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
218                     jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
219                     break;
220                 default:
221                     break;
222                 }
223             }
224
225             std::string input0_str, input1_str, cast_type, output_cast, op;
226
227             if (useVload8)
228             {
229                 cast_type = "(MAKE_VECTOR_TYPE(UNIT_TYPE, 8))";
230                 op = "const MAKE_VECTOR_TYPE(UNIT_TYPE, 8) tmp" + op_num_str + " = ";
231             }
232             else if(params.int8_quantization)
233             {
234                 cast_type = "(int)";
235                 op = "const int tmp" + op_num_str + " = ";
236             }
237             else
238             {
239                 cast_type = "(UNIT_TYPE)";
240                 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
241             }
242
243             if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) {
244                 output_cast = "(char)";
245                 cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")";
246             }
247
248             input0_str = cast_type + "INPUT_" + op_num_str + "_0";
249             input1_str = cast_type + "INPUT_" + op_num_str + "_1";
250
251             if (ew.mode == EltwiseMode::ADD)
252             {
253                 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
254                 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
255                 {
256                     const auto& input = ew.inputs[input_idx];
257                     if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
258                     {
259                         const float c = coefficients[input.index];
260                         if (c != 1.0f)
261                             coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
262                     }
263                 }
264
265                 input0_str = coeff_strings[0] + input0_str;
266                 input1_str = coeff_strings[1] + input1_str;
267             }
268
269
270             switch (ew.mode)
271             {
272             case EltwiseMode::ADD:         op += input0_str + " + " + input1_str; break;
273             case EltwiseMode::SUB:         op += input0_str + " - " + input1_str; break;
274             case EltwiseMode::MUL:         op += input0_str + " * " + input1_str; break;
275             case EltwiseMode::DIV:         op += input0_str + " / " + input1_str; break;
276             case EltwiseMode::MODULU:
277             case EltwiseMode::MIN:
278             case EltwiseMode::MAX:
279             {
280                 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max" ));
281                 auto input_0_type = params.inputs[0].GetDType();
282                 auto input_1_type = params.inputs[1].GetDType();
283
284                 // input_0 == int
285                 if (input_0_type == kernel_selector::Datatype::INT8 ||
286                     input_0_type == kernel_selector::Datatype::INT32 ||
287                     input_0_type == kernel_selector::Datatype::INT64)
288                 {
289                     // input_0 == int && input_1 == int
290                     if (input_1_type == kernel_selector::Datatype::INT8 ||
291                         input_1_type == kernel_selector::Datatype::INT32 ||
292                         input_1_type == kernel_selector::Datatype::INT64)
293                     {
294                         if (ew.mode == EltwiseMode::MODULU)
295                             op += input0_str + " % " + input1_str;
296                         else
297                             op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
298                     }
299                     // input_0 == int && input_1 != int
300                     else
301                     {
302                         op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
303                     }
304                 }
305                 // input_0 != int && input_1 == int
306                 else if ( input_1_type == kernel_selector::Datatype::INT8 ||
307                           input_1_type == kernel_selector::Datatype::INT32 ||
308                           input_1_type == kernel_selector::Datatype::INT64)
309                 {
310                     op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
311                 }
312                 // input_0 != int && input_1 != int
313                 else
314                 {
315                     op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
316                 }
317             } break;
318             case EltwiseMode::POW:          op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
319             case EltwiseMode::SQRT:         op += cast_type + "sqrt(" + input0_str + ")"; break;
320             case EltwiseMode::RSQRT:        op += cast_type + "1/sqrt(" + input0_str + ")"; break;
321             case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
322                                                             " * (" + input0_str + " - " + input1_str + "))"; break;
323             case EltwiseMode::EQ:           op += output_cast + "(" + input0_str + " == " + input1_str + ")"; break;
324             case EltwiseMode::NE:           op += output_cast + "(" + input0_str + " != " + input1_str + ")"; break;
325             case EltwiseMode::LT:           op += output_cast + "(" + input0_str + " < " + input1_str + ")"; break;
326             case EltwiseMode::LE:           op += output_cast + "(" + input0_str + " <= " + input1_str + ")"; break;
327             case EltwiseMode::GT:           op += output_cast + "(" + input0_str + " > " + input1_str + ")"; break;
328             case EltwiseMode::GE:           op += output_cast + "(" + input0_str + " >= " + input1_str + ")"; break;
329             case EltwiseMode::LOGIC_AND:    op += output_cast + "(" + input0_str + " && " + input1_str + ")"; break;
330             case EltwiseMode::LOGIC_OR:     op += output_cast + "(" + input0_str + " || " + input1_str + ")"; break;
331             case EltwiseMode::LOGIC_XOR:    op += output_cast + "(!" + input0_str + " != !" + input1_str + ")"; break;
332             case EltwiseMode::ASSIGN:       op += input0_str; break;
333             default:
334                 break;
335             }
336
337             std::string opname = "OPERATION" + op_num_str;
338             jit.AddConstant(MakeJitConstant(opname, op));
339             do_eltwise += "\\\n\t" + opname + ";";
340         }
341
342         for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
343             do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
344             "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
345             ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
346
347         do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
348
349         jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
350
351         if (params.layoutBased || params.int8_quantization || params.broadcast)
352         {
353             jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
354         }
355
356         if (!params.stride.empty())
357         {
358             jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
359         }
360
361         return jit;
362     }
363
364     JitConstants EltwiseKernelBase::GetJitConstants(const eltwise_params& params) const
365     {
366         return GetJitConstantsCommon(params, false);
367     }
368
369     EltwiseKernelBase::DispatchData EltwiseKernelBase::SetDefault(const eltwise_params& params) const
370     {
371         DispatchData kd;
372
373         if (params.layoutBased || params.int8_quantization || params.broadcast)
374         {
375             auto global = GetTensorFriendlyWorkGroups(params.output);
376             kd.gws0 = global[0];
377             kd.gws1 = global[1];
378             kd.gws2 = global[2];
379             if (!params.stride.empty())
380             {
381                 kd.gws0 /= params.stride[0].x;
382                 kd.gws0 /= params.stride[0].y;
383             }
384         }
385         else if (CheckInputsOutputNoPitchSameDims(params))
386         {
387             kd.gws0 = params.inputs[0].LogicalSize();
388             kd.gws1 = 1;
389             kd.gws2 = 1;
390         }
391         else
392         {
393             const auto& out = params.output;
394
395             std::vector<size_t> gws;
396             for (const auto& o : out.GetDims())
397             {
398                 gws.push_back(o.v);
399             }
400
401             for (size_t i = gws.size(); i < 4; i++)
402             {
403                 gws.push_back(1U);
404             }
405
406             kd.gws0 = gws[0];
407             kd.gws1 = gws[1];
408             kd.gws2 = gws[2] * gws[3];
409         }
410
411         auto local = GetOptimalLocalWorkGroupSizes( { kd.gws0, kd.gws1, kd.gws2 } );
412         kd.lws0 = local[0];
413         kd.lws1 = local[1];
414         kd.lws2 = local[2];
415
416         return kd;
417     }
418
419     KernelsData EltwiseKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
420     {
421         if (!Validate(params, options))
422         {
423             return{};
424         }
425
426         KernelData kd = KernelData::Default<eltwise_params>(params);
427         eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
428
429         auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
430         auto cldnn_jit = GetJitConstants(newParams);
431         std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
432
433         DispatchData runInfo = SetDefault(newParams);
434
435         auto& kernel = kd.kernels[0];
436
437         kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
438         kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
439
440         kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
441         kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false, newParams.int8_quantization, newParams.output_calibration);
442
443         kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
444
445         return{ kd };
446     }
447 }