ce172f764971d86748d75626266e8d96f000b3c8
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_base.cpp
1 // Copyright (c) 2016-2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "eltwise_kernel_base.h"
16 #include "kernel_selector_utils.h"
17 #include <string>
18 #include <vector>
19
20 namespace kernel_selector {
21 static uint32_t GetNumberOfInputs(EltwiseMode m) {
22     switch (m) {
23         case EltwiseMode::ADD:
24         case EltwiseMode::SUB:
25         case EltwiseMode::MUL:
26         case EltwiseMode::DIV:
27         case EltwiseMode::MIN:
28         case EltwiseMode::MAX:
29         case EltwiseMode::POW:
30         case EltwiseMode::MODULU:
31         case EltwiseMode::EQ:
32         case EltwiseMode::NE:
33         case EltwiseMode::LT:
34         case EltwiseMode::LE:
35         case EltwiseMode::GT:
36         case EltwiseMode::GE:
37         case EltwiseMode::LOGIC_AND:
38         case EltwiseMode::LOGIC_OR:
39         case EltwiseMode::LOGIC_XOR:
40         case EltwiseMode::SQUARED_DIFF:
41         case EltwiseMode::FLOOR_MOD:
42             return 2;
43         case EltwiseMode::SQRT:
44         case EltwiseMode::RSQRT:
45         case EltwiseMode::ASSIGN:
46             return 1;
47         default:
48             return 0;
49     }
50 }
51
52 ParamsKey eltwise_params::GetParamsKey() const {
53     ParamsKey k = base_params::GetParamsKey();
54     if (int8_quantization) {
55         k.EnableInt8Quantization();
56     }
57
58     if (output_calibration) {
59         k.EnableOutputCalibration();
60     }
61
62     if (inputs_calibration) {
63         k.EnableEltwiseInputsCalibration();
64     }
65
66     if (!stride.empty()) {
67         k.EnableEltwiseStride();
68     }
69
70     if (broadcast) {
71         k.EnableEltwiseBroadcast();
72     }
73
74     return k;
75 }
76
77 bool EltwiseKernelBase::Validate(const Params& p, const optional_params& o) const {
78     if (p.GetType() != KernelType::ELTWISE || o.GetType() != KernelType::ELTWISE) {
79         return false;
80     }
81
82     const eltwise_params& params = static_cast<const eltwise_params&>(p);
83
84     if (params.inputs.size() == 0) {
85         return false;
86     }
87
88     auto& operations = params.operations;
89
90     if (operations.size() == 0) {
91         return false;
92     }
93
94     for (size_t op_num = 0; op_num < operations.size(); op_num++) {
95         const auto& ew = operations[op_num];
96
97         if (ew.inputs.size() != GetNumberOfInputs(ew.mode)) {
98             return false;
99         }
100
101         for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
102             const auto& input = ew.inputs[input_idx];
103             if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index >= params.inputs.size()) {
104                 return false;
105             }
106         }
107     }
108
109     return true;
110 }
111
112 JitConstants EltwiseKernelBase::GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const {
113     JitConstants jit = MakeBaseParamsJitConstants(params);
114
115     auto GetIdxOrderForLayout = [&](DataLayout l, bool layoutBased, uSize stride) -> std::string {
116         // TODO: Generalize this method
117         std::vector<std::string> bfyx_idx_order = {};
118         if (layoutBased) {
119             bfyx_idx_order = { "d4", "d3", "d2", "d1" };
120         } else {
121             if (l == DataLayout::yxfb) {
122                 bfyx_idx_order = { "d1", "d2", "d3", "d4" };
123             } else if (l == DataLayout::fyxb) {
124                 bfyx_idx_order = { "d1", "d4", "d3", "d2" };
125             } else {
126                 bfyx_idx_order = { "d4", "d3", "d2", "d1" };
127             }
128         }
129
130         if (!params.stride.empty()) {
131             bfyx_idx_order[2] = "(" + bfyx_idx_order[2] + "*" + std::to_string(stride.y) + ")";
132             bfyx_idx_order[3] = "(" + bfyx_idx_order[3] + "*" + std::to_string(stride.x) + ")";
133         }
134
135         return bfyx_idx_order[0] + "," +
136                bfyx_idx_order[1] + "," +
137                bfyx_idx_order[2] + "," +
138                bfyx_idx_order[3];
139     };
140
141     jit.AddConstants({
142         MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
143         MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
144         MakeJitConstant("ELTWISE_BROADCAST", params.broadcast),
145     });
146
147     if (params.int8_quantization) {
148         if (params.output_calibration) {
149             jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
150             jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
151
152         } else {
153             jit.AddConstants({MakeJitConstant("O_QF", params.output_quantization_factor)});
154         }
155     }
156
157     std::string inputs_decls, vload_decls;
158     auto& updateInputs = params.updateInputIds;
159
160     std::string out_idx_order = "OUTPUT_IDX_ORDER";
161     uSize out_stride = {1, 1, 1};
162     if (useVload8) {
163         jit.AddConstant(MakeJitConstant(out_idx_order, "d1"));
164     } else {
165         if (CheckInputsOutputNoPitchSameDims(params) &&
166             !(params.layoutBased || params.int8_quantization || params.broadcast)) {
167             jit.AddConstant(MakeJitConstant(out_idx_order, "d1"));
168         } else {
169             size_t out_c = DataTensor::ChannelsCount(params.output.GetLayout());
170             if (out_c <= 4) {
171                 jit.AddConstant(MakeJitConstant(out_idx_order, GetIdxOrderForLayout(params.output.GetLayout(),
172                                                                                     params.layoutBased || params.broadcast,
173                                                                                     out_stride)));
174             } else if (out_c == 5) {
175                 jit.AddConstant(MakeJitConstant(out_idx_order, "d5,d4,d3,d2,d1"));
176             } else {
177                 assert(0);
178             }
179         }
180     }
181     if (!params.stride.empty()) {
182         jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_X", out_stride.x));
183         jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_Y", out_stride.y));
184         jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_Z", out_stride.z));
185     }
186
187     for (size_t i = 0; i < params.inputs.size(); i++) {
188         // const should be added only to inputs which will not be updated
189         std::string const_str = "const";
190         for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++) {
191             if (updateInputs[update_input_idx].inputId == i) {
192                 const_str = "";
193                 break;
194             }
195         }
196
197         inputs_decls +=
198             const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
199         if (!params.stride.empty()) {
200             jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
201             jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
202             jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Z", params.stride[i].z));
203         }
204         std::string idx_order = "INPUT" + std::to_string(i) + "_IDX_ORDER";
205         if (useVload8) {
206             vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
207             if (params.inputs[i].PhysicalSize() == 1)  // Scalar case
208                 vload_decls += " = (" + toCLType(params.inputs[i].GetDType()) + "8)(input" + std::to_string(i) + "[0]";
209             else  // Buffer case
210                 vload_decls += " = vload8(global_id, input" + std::to_string(i);
211             vload_decls += ");";
212             jit.AddConstant(MakeJitConstant(idx_order, "d1"));
213         } else {
214             if (CheckInputsOutputNoPitchSameDims(params) &&
215                 !(params.layoutBased || params.int8_quantization || params.broadcast)) {
216                 jit.AddConstant(MakeJitConstant(idx_order, "d1"));
217             } else {
218                 size_t in_c = DataTensor::ChannelsCount(params.inputs[i].GetLayout());
219                 size_t out_c = DataTensor::ChannelsCount(params.output.GetLayout());
220                 auto in_stride = params.stride.empty() ? out_stride : params.stride[i];
221                 if (out_c <= 4 && in_c <= 4) {
222                     jit.AddConstant(MakeJitConstant(idx_order, GetIdxOrderForLayout(params.inputs[i].GetLayout(),
223                                                                                     params.layoutBased || params.broadcast,
224                                                                                     in_stride)));
225                 } else if (out_c == 5) {
226                     if (in_c < 5) {
227                         // Skip Z coord for 4d tensors
228                         jit.AddConstant(MakeJitConstant(idx_order, "d5,d4,d2,d1"));
229                     } else if (in_c == 5) {
230                         jit.AddConstant(MakeJitConstant(idx_order, "d5,d4,d3,d2,d1"));
231                     }
232                 } else if (out_c <= 4 && in_c == 5) {
233                     // quite strange case, but can happen due to reorders fusing
234                     // it means that z coord is equal to 1, so z offset will be always equal to 0
235                     jit.AddConstant(MakeJitConstant(idx_order, "d4,d3,0,d2,d1"));
236                 } else {
237                     assert(0);
238                 }
239             }
240         }
241     }
242
243     jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
244     jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
245
246     if (useVload8)
247         jit.AddConstant(MakeJitConstant("VLOAD_DECLS", vload_decls));
248
249     std::string do_eltwise;
250
251     auto& operations = params.operations;
252     auto& coefficients = params.coefficients;
253
254     for (size_t op_num = 0; op_num < operations.size(); op_num++) {
255         const std::string op_num_str = std::to_string(op_num);
256         const auto& ew = operations[op_num];
257
258         for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
259             const auto& input = ew.inputs[input_idx];
260             const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
261             std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
262
263             switch (input.mode) {
264                 case EltwiseInputMode::SCALAR:
265                     jit.AddConstant(MakeJitConstant(name, input.scalar));
266                     break;
267                 case EltwiseInputMode::INPUT_BUFFER:
268                     if (useVload8)
269                         jit.AddConstant(MakeJitConstant(name, "in" + std::to_string(input.index)));
270                     else
271                         jit.AddConstant(MakeJitConstant(name,
272                                                         "input" + std::to_string(input.index) +
273                                                         "[GET_INDEX(INPUT, " + std::to_string(input.index) +
274                                                         "," + idx_order +")]"));
275                     break;
276                 case EltwiseInputMode::OUTPUT_BUFFER:
277                     jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT,,"+ out_idx_order +")]"));
278                     break;
279                 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
280                     jit.AddConstant(MakeJitConstant(
281                         name,
282                         "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
283                     break;
284                 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
285                     jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
286                     break;
287                 default:
288                     break;
289             }
290         }
291
292         std::string input0_str, input1_str, cast_type, output_cast, op;
293
294         if (useVload8) {
295             cast_type = "(MAKE_VECTOR_TYPE(UNIT_TYPE, 8))";
296             op = "const MAKE_VECTOR_TYPE(UNIT_TYPE, 8) tmp" + op_num_str + " = ";
297         } else if (params.int8_quantization) {
298             cast_type = "(int)";
299             op = "const int tmp" + op_num_str + " = ";
300         } else {
301             cast_type = "(UNIT_TYPE)";
302             op = "const UNIT_TYPE tmp" + op_num_str + " = ";
303         }
304
305         if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) {
306             output_cast = "(char)";
307             cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")";
308         }
309
310         input0_str = cast_type + "INPUT_" + op_num_str + "_0";
311         input1_str = cast_type + "INPUT_" + op_num_str + "_1";
312
313         if (ew.mode == EltwiseMode::ADD) {
314             std::vector<std::string> coeff_strings(ew.inputs.size(), "");
315             for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
316                 const auto& input = ew.inputs[input_idx];
317                 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size()) {
318                     const float c = coefficients[input.index];
319                     if (c != 1.0f)
320                         coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
321                 }
322             }
323
324             input0_str = coeff_strings[0] + input0_str;
325             input1_str = coeff_strings[1] + input1_str;
326         }
327
328         switch (ew.mode) {
329             case EltwiseMode::ADD:
330                 op += input0_str + " + " + input1_str;
331                 break;
332             case EltwiseMode::SUB:
333                 op += input0_str + " - " + input1_str;
334                 break;
335             case EltwiseMode::MUL:
336                 op += input0_str + " * " + input1_str;
337                 break;
338             case EltwiseMode::DIV:
339                 op += input0_str + " / " + input1_str;
340                 break;
341             case EltwiseMode::MODULU:
342             case EltwiseMode::MIN:
343             case EltwiseMode::MAX: {
344                 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
345                 auto input_0_type = params.inputs[0].GetDType();
346                 auto input_1_type = params.inputs[1].GetDType();
347
348                 // input_0 == int
349                 if (input_0_type == kernel_selector::Datatype::INT8 ||
350                     input_0_type == kernel_selector::Datatype::INT32 ||
351                     input_0_type == kernel_selector::Datatype::INT64) {
352                     // input_0 == int && input_1 == int
353                     if (input_1_type == kernel_selector::Datatype::INT8 ||
354                         input_1_type == kernel_selector::Datatype::INT32 ||
355                         input_1_type == kernel_selector::Datatype::INT64) {
356                         if (ew.mode == EltwiseMode::MODULU)
357                             op += input0_str + " % " + input1_str;
358                         else
359                             op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
360                     } else {
361                     // input_0 == int && input_1 != int
362                         op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
363                     }
364                 } else if (input_1_type == kernel_selector::Datatype::INT8 ||
365                          input_1_type == kernel_selector::Datatype::INT32 ||
366                          input_1_type == kernel_selector::Datatype::INT64) {
367                     // input_0 != int && input_1 == int
368                     op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
369                 } else {
370                     // input_0 != int && input_1 != int
371                     op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
372                 }
373             } break;
374             case EltwiseMode::POW:
375                 op += cast_type + "pow(" + input0_str + ", " + input1_str + ")";
376                 break;
377             case EltwiseMode::SQRT:
378                 op += cast_type + "sqrt(" + input0_str + ")";
379                 break;
380             case EltwiseMode::RSQRT:
381                 op += cast_type + "1/sqrt(" + input0_str + ")";
382                 break;
383             case EltwiseMode::SQUARED_DIFF:
384                 op += cast_type + "((" + input0_str + " - " + input1_str +
385                       ")"
386                       " * (" +
387                       input0_str + " - " + input1_str + "))";
388                 break;
389             case EltwiseMode::EQ:
390                 op += output_cast + "(" + input0_str + " == " + input1_str + ")";
391                 break;
392             case EltwiseMode::NE:
393                 op += output_cast + "(" + input0_str + " != " + input1_str + ")";
394                 break;
395             case EltwiseMode::LT:
396                 op += output_cast + "(" + input0_str + " < " + input1_str + ")";
397                 break;
398             case EltwiseMode::LE:
399                 op += output_cast + "(" + input0_str + " <= " + input1_str + ")";
400                 break;
401             case EltwiseMode::GT:
402                 op += output_cast + "(" + input0_str + " > " + input1_str + ")";
403                 break;
404             case EltwiseMode::GE:
405                 op += output_cast + "(" + input0_str + " >= " + input1_str + ")";
406                 break;
407             case EltwiseMode::LOGIC_AND:
408                 op += output_cast + "(" + input0_str + " && " + input1_str + ")";
409                 break;
410             case EltwiseMode::LOGIC_OR:
411                 op += output_cast + "(" + input0_str + " || " + input1_str + ")";
412                 break;
413             case EltwiseMode::LOGIC_XOR:
414                 op += output_cast + "(!" + input0_str + " != !" + input1_str + ")";
415                 break;
416             case EltwiseMode::FLOOR_MOD:
417                 op += output_cast + "(" + input0_str + " - " + input0_str + " / " + input1_str + " * " + input1_str + ")";
418                 break;
419             case EltwiseMode::ASSIGN:
420                 op += input0_str;
421                 break;
422             default:
423                 break;
424         }
425
426         std::string opname = "OPERATION" + op_num_str;
427         jit.AddConstant(MakeJitConstant(opname, op));
428         do_eltwise += "\\\n\t" + opname + ";";
429     }
430
431     for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
432         do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) + "[GET_INDEX(INPUT, " +
433                       std::to_string(updateInputs[update_input_idx].inputId) + ", " +
434                       "INPUT"+std::to_string(updateInputs[update_input_idx].inputId) + "_IDX_ORDER)] = tmp" +
435                       std::to_string(updateInputs[update_input_idx].tmpId) + ";";
436
437     do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
438
439     jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
440
441     if (params.layoutBased || params.int8_quantization || params.broadcast) {
442         jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
443     }
444
445     if (!params.stride.empty()) {
446         jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
447     }
448
449     return jit;
450 }
451
452 JitConstants EltwiseKernelBase::GetJitConstants(const eltwise_params& params) const {
453     return GetJitConstantsCommon(params, false);
454 }
455
456 EltwiseKernelBase::DispatchData EltwiseKernelBase::SetDefault(const eltwise_params& params) const {
457     DispatchData kd;
458
459     if (params.layoutBased || params.int8_quantization || params.broadcast) {
460         auto global = GetTensorFriendlyWorkGroups(params.output);
461         kd.gws0 = global[0];
462         kd.gws1 = global[1];
463         kd.gws2 = global[2];
464     } else if (CheckInputsOutputNoPitchSameDims(params)) {
465         kd.gws0 = params.output.LogicalSize();
466         kd.gws1 = 1;
467         kd.gws2 = 1;
468     } else {
469         const auto& out = params.output;
470
471         std::vector<size_t> gws;
472         for (const auto& o : out.GetDims()) {
473             gws.push_back(o.v);
474         }
475
476         size_t n_dims;
477         if (out.GetLayout() == DataLayout::bfzyx)
478             n_dims = 5;
479         else
480             n_dims = 4;
481
482         for (size_t i = gws.size(); i < n_dims; i++) {
483             gws.push_back(1U);
484         }
485
486         kd.gws0 = gws[0];
487         if (n_dims == 5) {
488             kd.gws1 = gws[1] * gws[2];  // y*z
489             kd.gws2 = gws[3] * gws[4];
490         } else {
491             kd.gws1 = gws[1];
492             kd.gws2 = gws[2] * gws[3];
493         }
494     }
495
496     auto local = GetOptimalLocalWorkGroupSizes({kd.gws0, kd.gws1, kd.gws2});
497
498     if (params.output.GetLayout() == DataLayout::bfyx_f16 && params.output.Feature().v % 16 == 0 &&
499         kd.gws1 % 16 == 0) {
500         kd.lws0 = 1;
501         kd.lws1 = 16;
502         kd.lws2 = 1;
503     } else {
504         kd.lws0 = local[0];
505         kd.lws1 = local[1];
506         kd.lws2 = local[2];
507     }
508
509     return kd;
510 }
511
512 KernelsData EltwiseKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const {
513     if (!Validate(params, options)) {
514         return {};
515     }
516
517     KernelData kd = KernelData::Default<eltwise_params>(params);
518     eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
519
520     auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
521     auto cldnn_jit = GetJitConstants(newParams);
522     std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
523
524     DispatchData runInfo = SetDefault(newParams);
525
526     auto& kernel = kd.kernels[0];
527
528     kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
529     kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
530
531     kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
532     kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(),
533                                    false,
534                                    false,
535                                    newParams.int8_quantization,
536                                    newParams.output_calibration);
537
538     kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
539
540     return {kd};
541 }
542 }  // namespace kernel_selector