Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_b_fs_yx_fsv4.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "eltwise_kernel_b_fs_yx_fsv4.h"
18 #include "kernel_selector_utils.h" 
19
20 namespace kernel_selector {
21
22     ParamsKey EltwiseKernel_b_fs_yx_fsv4::GetSupportedKey() const
23     {
24         ParamsKey k;
25         k.EnableInputDataType(Datatype::INT8);
26         k.EnableInputDataType(Datatype::UINT8);
27         k.EnableOutputDataType(Datatype::INT8);
28         k.EnableOutputDataType(Datatype::UINT8);
29         k.EnableInputLayout(DataLayout::b_fs_yx_fsv4);
30         k.EnableOutputLayout(DataLayout::b_fs_yx_fsv4);
31         k.EnableTensorOffset();
32         k.EnableTensorPitches();
33         k.EnableBatching();
34         k.EnableInt8Quantization();
35         k.EnableOutputCalibration();
36         k.EnableEltwiseStride();
37         return k;
38     }
39
40     EltwiseKernelBase::DispatchData EltwiseKernel_b_fs_yx_fsv4::SetDefault(const eltwise_params& params) const
41     {
42         DispatchData kd;
43
44         // Because of very specific requirements for data, we may linearize the data,
45         // i.e. use only one dimension, e.g. 'X'.
46
47         //GWS:
48         // we process 4*4 (4 int8 bytes per on block_read4 reading) features per workitem
49         kd.gws0 = params.output.X().v * params.output.Y().v *
50                   params.output.Batch().v * params.output.Feature().v / (4*4);
51         kd.gws1 = 1;
52         kd.gws2 = 1;
53         // LWS:
54         kd.lws0 = 8;
55         kd.lws1 = 1;
56         kd.lws2 = 1;
57
58         kd.effiency = FORCE_PRIORITY_1;
59         return kd;
60     }
61
62     bool EltwiseKernel_b_fs_yx_fsv4::Validate(const Params& params, const optional_params& options) const
63     {
64         // Requirents to use 'eltwise_b_fs_yx_fsv4' kernel are below:
65         // 1. No stride
66         // 2. All dimensions for all inputs are the same
67         // 3. No padding
68         // So, it can be linearized
69
70         if (!Parent::Validate(params, options)) {
71             return false;
72         }
73
74         KernelData kd = KernelData::Default<eltwise_params>(params);
75         eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
76
77         // 1. No stride
78         if (!newParams.stride.empty()) {
79             return false;
80         }
81
82         for (size_t i = 0; i < newParams.inputs.size() - 1; i++)
83         {
84             // 2. All dimensions for all inputs are the same
85             if (!(newParams.inputs[i] == newParams.inputs[i + 1])) {
86                 return false;
87             }
88         }
89
90         const auto& in = newParams.inputs[0];
91         for (size_t i = 0; i < in.Dimentions(); i++)
92         {
93             // 3. No padding
94             if ((in.GetDims()[i].pad.before != 0) ||
95                 (in.GetDims()[i].pad.after != 0)) {
96                 return false;
97             }
98         }
99
100         return true;
101     }
102
103     JitConstants EltwiseKernel_b_fs_yx_fsv4::GetJitConstants(const eltwise_params& params) const
104     {
105         JitConstants jit = MakeBaseParamsJitConstants(params);
106
107         if (params.inputs[0].GetDType() == Datatype::UINT8) {
108             // Special handler for unsigned types
109             jit.AddConstants({
110                 MakeJitConstant("ELTW_UNSIGNED", 1)
111             });
112         }
113
114         ///////////////
115         jit.AddConstants({
116             MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
117             MakeJitConstant("QUANTIZATION_TERM",    params.int8_quantization),
118         });
119
120         if (params.int8_quantization)
121         {
122             if (params.output_calibration)
123             {
124                 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
125                 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
126
127             }
128             else
129                 jit.AddConstants({ MakeJitConstant("O_QF",       params.output_quantization_factor) });
130         }
131
132         std::string inputs_decls;
133         auto& updateInputs = params.updateInputIds;
134
135         for (size_t i = 0; i < params.inputs.size(); i++)
136         {
137             //const should be added only to inputs which will not be updated
138             std::string const_str = "const";
139             for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
140             {
141                 if (updateInputs[update_input_idx].inputId == i)
142                 {
143                     const_str = "";
144                     break;
145                 }
146             }
147
148             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
149         }
150
151         jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
152         jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
153
154         std::string do_eltwise;
155
156         auto& operations = params.operations;
157         auto& coefficients = params.coefficients;
158
159         for (size_t op_num = 0; op_num < operations.size(); op_num++)
160         {
161             const std::string op_num_str = std::to_string(op_num);
162             const auto& ew = operations[op_num];
163
164             for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
165             {
166                 const auto& input = ew.inputs[input_idx];
167                 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
168                 switch (input.mode)
169                 {
170                 case EltwiseInputMode::SCALAR:
171                     jit.AddConstant(MakeJitConstant(name, input.scalar));
172                     break;
173                 case EltwiseInputMode::INPUT_BUFFER:
174                     jit.AddConstant(MakeJitConstant(name, "GET_INPUT(input" + std::to_string(input.index) + ", INPUT" + std::to_string(input.index) + ")"));
175                     break;
176                 case EltwiseInputMode::OUTPUT_BUFFER:
177                     jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
178                     break;
179                 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
180                     jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
181                     break;
182                 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
183                     jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
184                     break;
185                 default:
186                     break;
187                 }
188             }
189             std::string input0_str, input1_str, cast_type, op;
190
191             cast_type = "(int16)";
192             op = "const int16 tmp" + op_num_str + " = ";
193
194             input0_str = cast_type + "INPUT_" + op_num_str + "_0";
195             input1_str = cast_type + "INPUT_" + op_num_str + "_1";
196
197             if (ew.mode == EltwiseMode::ADD)
198             {
199                 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
200                 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
201                 {
202                     const auto& input = ew.inputs[input_idx];
203                     if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
204                     {
205                         const float c = coefficients[input.index];
206                         if (c != 1.0f)
207                             coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
208                     }
209                 }
210
211                 input0_str = coeff_strings[0] + input0_str;
212                 input1_str = coeff_strings[1] + input1_str;
213             }
214
215
216             switch (ew.mode)
217             {
218             case EltwiseMode::ADD:      op += input0_str + " + " + input1_str; break;
219             case EltwiseMode::SUB:      op += input0_str + " - " + input1_str; break;
220             case EltwiseMode::MUL:      op += input0_str + " * " + input1_str; break;
221             case EltwiseMode::DIV:      op += input0_str + " / " + input1_str; break;
222             case EltwiseMode::MODULU:
223             case EltwiseMode::MIN:
224             case EltwiseMode::MAX:
225             {
226                 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
227                 auto input_0_type = params.inputs[0].GetDType();
228                 auto input_1_type = params.inputs[1].GetDType();
229
230                 // input_0 == int
231                 if (input_0_type == kernel_selector::Datatype::INT8 ||
232                     input_0_type == kernel_selector::Datatype::UINT8)
233                 {
234                     // input_0 == int && input_1 == int
235                     if (input_1_type == kernel_selector::Datatype::INT8 ||
236                         input_1_type == kernel_selector::Datatype::UINT8)
237                     {
238                         if (ew.mode == EltwiseMode::MODULU)
239                             op += input0_str + " % " + input1_str;
240                         else
241                             op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
242                     }
243                     // input_0 == int && input_1 != int
244                     else
245                     {
246                         op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
247                     }
248                 }
249                 // input_0 != int && input_1 == int
250                 else if (input_1_type == kernel_selector::Datatype::INT8 ||
251                          input_1_type == kernel_selector::Datatype::UINT8)
252                 {
253                     op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
254                 }
255                 // input_0 != int && input_1 != int
256                 else
257                 {
258                     op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
259                 }
260             } break;
261             case EltwiseMode::POW:      op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
262             case EltwiseMode::SQRT:     op += cast_type + "sqrt(" + input0_str + ")"; break;
263             case EltwiseMode::RSQRT:    op += cast_type + "1/sqrt(" + input0_str + ")"; break;
264             case EltwiseMode::ASSIGN:   op += input0_str; break;
265             default:
266                 break;
267             }
268
269             std::string opname = "OPERATION" + op_num_str;
270             jit.AddConstant(MakeJitConstant(opname, op));
271             do_eltwise += "\\\n\t" + opname + ";";
272         }
273
274         for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
275             do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
276             "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
277             ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
278
279         do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
280
281         jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
282
283         if (params.layoutBased || params.int8_quantization)
284         {
285             jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
286         }
287
288         if (!params.stride.empty())
289         {
290             jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
291         }
292
293         ///////////////
294         return jit;
295     }
296
297     KernelsData EltwiseKernel_b_fs_yx_fsv4::GetKernelsData(const Params& params, const optional_params& options) const
298     {
299         return GetCommonKernelsData(params, options);
300     }
301 }