Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_fs_bs_yx_bsv4_fsv32.cpp
1 /*
2 // Copyright (c) 2018-2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "eltwise_kernel_fs_bs_yx_bsv4_fsv32.h"
18 #include "kernel_selector_utils.h"
19
20 namespace kernel_selector {
21
22     ParamsKey EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetSupportedKey() const
23     {
24         ParamsKey k;
25         k.EnableInputDataType(Datatype::INT8);
26         k.EnableOutputDataType(Datatype::INT8);
27         k.EnableInputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
28         k.EnableOutputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
29         k.EnableTensorOffset();
30         k.EnableTensorPitches();
31         k.EnableBatching();
32         k.EnableInt8Quantization();
33         k.EnableOutputCalibration();
34         k.EnableEltwiseStride();
35         return k;
36     }
37
38     EltwiseKernelBase::DispatchData EltwiseKernel_fs_bs_yx_bsv4_fsv32::SetDefault(const eltwise_params& params) const
39     {
40         DispatchData kd;
41
42         kd.gws0 = params.output.X().v;
43         kd.gws1 = params.output.Y().v;
44         // we process 4 batches and 4 features per workitem
45         kd.gws2 = (params.output.Batch().v / 4) * (params.output.Feature().v / 4);
46         kd.lws0 = 1;
47         kd.lws1 = 1;
48         kd.lws2 = 8;
49
50         kd.effiency = FORCE_PRIORITY_3;
51         return kd;
52     }
53
54     JitConstants EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetJitConstants(const eltwise_params& params) const
55     {
56         JitConstants jit = MakeBaseParamsJitConstants(params);
57
58         const size_t in_x_pitch = 32 * 4;
59         const size_t in_y_pitch = 32 * 4 * params.inputs[0].X().LogicalDimPadded();
60         const size_t in_b_block_pitch = in_y_pitch * params.inputs[0].Y().LogicalDimPadded();
61         const size_t in_f_block_pitch = in_b_block_pitch * ((params.inputs[0].Batch().v + 3) / 4);
62         const size_t in_offset = in_x_pitch * params.inputs[0].X().pad.before + in_y_pitch * params.inputs[0].Y().pad.before;
63
64         jit.AddConstant(MakeJitConstant("IN_X_PITCH", in_x_pitch));
65         jit.AddConstant(MakeJitConstant("IN_Y_PITCH", in_y_pitch));
66         jit.AddConstant(MakeJitConstant("IN_B_BLOCK_PITCH", in_b_block_pitch));
67         jit.AddConstant(MakeJitConstant("IN_F_BLOCK_PITCH", in_f_block_pitch));
68         jit.AddConstant(MakeJitConstant("IN_OFFSET", in_offset));
69
70         ///////////////
71         jit.AddConstants({
72             MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
73             MakeJitConstant("QUANTIZATION_TERM",    params.int8_quantization),
74             });
75
76         if (params.int8_quantization)
77         {
78             if (params.output_calibration)
79             {
80                 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
81                 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
82
83             }
84             else
85                 jit.AddConstants({ MakeJitConstant("O_QF",       params.output_quantization_factor) });
86         }
87
88         std::string inputs_decls;
89         auto& updateInputs = params.updateInputIds;
90
91         for (size_t i = 0; i < params.inputs.size(); i++)
92         {
93             //const should be added only to inputs which will not be updated
94             std::string const_str = "const";
95             for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
96             {
97                 if (updateInputs[update_input_idx].inputId == i)
98                 {
99                     const_str = "";
100                     break;
101                 }
102             }
103
104             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
105
106             if (!params.stride.empty())
107             {
108                 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
109                 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
110             }
111         }
112
113         jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
114         jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
115
116         std::string do_eltwise;
117
118         auto& operations = params.operations;
119         auto& coefficients = params.coefficients;
120
121         for (size_t op_num = 0; op_num < operations.size(); op_num++)
122         {
123             const std::string op_num_str = std::to_string(op_num);
124             const auto& ew = operations[op_num];
125
126             for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
127             {
128                 const auto& input = ew.inputs[input_idx];
129                 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
130                 switch (input.mode)
131                 {
132                 case EltwiseInputMode::SCALAR:
133                     jit.AddConstant(MakeJitConstant(name, input.scalar));
134                     break;
135                 case EltwiseInputMode::INPUT_BUFFER:
136                     jit.AddConstant(MakeJitConstant(name, "GET_INPUT(input" + std::to_string(input.index) + ", INPUT" + std::to_string(input.index) + ")"));
137                     break;
138                 case EltwiseInputMode::OUTPUT_BUFFER:
139                     jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
140                     break;
141                 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
142                     jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
143                     break;
144                 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
145                     jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
146                     break;
147                 default:
148                     break;
149                 }
150             }
151             std::string input0_str, input1_str, cast_type, op;
152
153             if (params.int8_quantization)
154             {
155                 cast_type = "(int16)";
156                 op = "const int16 tmp" + op_num_str + " = ";
157             }
158             else
159             {
160                 cast_type = "(UNIT_TYPE)";
161                 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
162             }
163
164             input0_str = cast_type + "INPUT_" + op_num_str + "_0";
165             input1_str = cast_type + "INPUT_" + op_num_str + "_1";
166
167             if (ew.mode == EltwiseMode::ADD)
168             {
169                 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
170                 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
171                 {
172                     const auto& input = ew.inputs[input_idx];
173                     if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
174                     {
175                         const float c = coefficients[input.index];
176                         if (c != 1.0f)
177                             coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
178                     }
179                 }
180
181                 input0_str = coeff_strings[0] + input0_str;
182                 input1_str = coeff_strings[1] + input1_str;
183             }
184
185
186             switch (ew.mode)
187             {
188             case EltwiseMode::ADD:          op += input0_str + " + " + input1_str; break;
189             case EltwiseMode::SUB:          op += input0_str + " - " + input1_str; break;
190             case EltwiseMode::MUL:          op += input0_str + " * " + input1_str; break;
191             case EltwiseMode::DIV:          op += input0_str + " / " + input1_str; break;
192             case EltwiseMode::MODULU:
193             case EltwiseMode::MIN:
194             case EltwiseMode::MAX:
195             {
196                 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
197                 auto input_0_type = params.inputs[0].GetDType();
198                 auto input_1_type = params.inputs[1].GetDType();
199
200                 // input_0 == int
201                 if (input_0_type == kernel_selector::Datatype::INT8 ||
202                     input_0_type == kernel_selector::Datatype::INT32 ||
203                     input_0_type == kernel_selector::Datatype::INT64)
204                 {
205                     // input_0 == int && input_1 == int
206                     if (input_1_type == kernel_selector::Datatype::INT8 ||
207                         input_1_type == kernel_selector::Datatype::INT32 ||
208                         input_1_type == kernel_selector::Datatype::INT64)
209                     {
210                         if (ew.mode == EltwiseMode::MODULU)
211                             op += input0_str + " % " + input1_str;
212                         else
213                             op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
214                     }
215                     // input_0 == int && input_1 != int
216                     else
217                     {
218                         op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
219                     }
220                 }
221                 // input_0 != int && input_1 == int
222                 else if (input_1_type == kernel_selector::Datatype::INT8 ||
223                     input_1_type == kernel_selector::Datatype::INT32 ||
224                     input_1_type == kernel_selector::Datatype::INT64)
225                 {
226                     op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
227                 }
228                 // input_0 != int && input_1 != int
229                 else
230                 {
231                     op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
232                 }
233             } break;
234             case EltwiseMode::POW:          op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
235             case EltwiseMode::SQRT:         op += cast_type + "sqrt(" + input0_str + ")"; break;
236             case EltwiseMode::RSQRT:        op += cast_type + "1/sqrt(" + input0_str + ")"; break;
237             case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
238                                                             " * (" + input0_str + " - " + input1_str + "))"; break;
239             case EltwiseMode::EQ:           op += cast_type + "(" + input0_str + " == " + input1_str + ")"; break;
240             case EltwiseMode::NE:           op += cast_type + "(" + input0_str + " != " + input1_str + ")"; break;
241             case EltwiseMode::LT:           op += cast_type + "(" + input0_str + " < " + input1_str + ")"; break;
242             case EltwiseMode::LE:           op += cast_type + "(" + input0_str + " <= " + input1_str + ")"; break;
243             case EltwiseMode::GT:           op += cast_type + "(" + input0_str + " > " + input1_str + ")"; break;
244             case EltwiseMode::GE:           op += cast_type + "(" + input0_str + " >= " + input1_str + ")"; break;
245             case EltwiseMode::LOGIC_AND:    op += cast_type + "(" + input0_str + " && " + input1_str + ")"; break;
246             case EltwiseMode::LOGIC_OR:     op += cast_type + "(" + input0_str + " || " + input1_str + ")"; break;
247             case EltwiseMode::LOGIC_XOR:    op += cast_type + "(!" + input0_str + " != !" + input1_str + ")"; break;
248             case EltwiseMode::ASSIGN:       op += input0_str; break;
249             default:
250                 break;
251             }
252
253             std::string opname = "OPERATION" + op_num_str;
254             jit.AddConstant(MakeJitConstant(opname, op));
255             do_eltwise += "\\\n\t" + opname + ";";
256         }
257
258         for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
259             do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
260             "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
261             ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
262
263         do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
264
265         jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
266
267         if (params.layoutBased || params.int8_quantization)
268         {
269             jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
270         }
271
272         if (!params.stride.empty())
273         {
274             jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
275         }
276
277         ///////////////
278         return jit;
279     }
280
281     KernelsData EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetKernelsData(const Params& params, const optional_params& options) const
282     {
283         return GetCommonKernelsData(params, options);
284     }
285 }