1 // Copyright (c) 2016-2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "eltwise_kernel_base.h"
16 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Maps an eltwise mode to the number of inputs an operation of that mode takes.
// Validate() compares this value against the size of an operation's input list.
// NOTE(review): the switch header and the return statement of each case group are
// not visible in this view. Presumably the first group (ADD .. FLOOR_MOD) returns 2
// and the second group (SQRT, RSQRT, ASSIGN) returns 1 - confirm against the full body.
21 static uint32_t GetNumberOfInputs(EltwiseMode m) {
23 case EltwiseMode::ADD:
24 case EltwiseMode::SUB:
25 case EltwiseMode::MUL:
26 case EltwiseMode::DIV:
27 case EltwiseMode::MIN:
28 case EltwiseMode::MAX:
29 case EltwiseMode::POW:
30 case EltwiseMode::MODULU:
37 case EltwiseMode::LOGIC_AND:
38 case EltwiseMode::LOGIC_OR:
39 case EltwiseMode::LOGIC_XOR:
40 case EltwiseMode::SQUARED_DIFF:
41 case EltwiseMode::FLOOR_MOD:
43 case EltwiseMode::SQRT:
44 case EltwiseMode::RSQRT:
45 case EltwiseMode::ASSIGN:
// Builds the ParamsKey for these eltwise parameters: starts from the key produced by
// base_params and enables one key feature per eltwise-specific option that is in use.
52 ParamsKey eltwise_params::GetParamsKey() const {
53 ParamsKey k = base_params::GetParamsKey();
// Quantization / calibration options each map to their own key flag.
54 if (int8_quantization) {
55 k.EnableInt8Quantization();
58 if (output_calibration) {
59 k.EnableOutputCalibration();
62 if (inputs_calibration) {
63 k.EnableEltwiseInputsCalibration();
// A non-empty stride list marks the strided variant.
66 if (!stride.empty()) {
67 k.EnableEltwiseStride();
// NOTE(review): the condition guarding this call is not visible in this view;
// presumably it tests the broadcast member - confirm.
71 k.EnableEltwiseBroadcast();
// Checks that the given params / optional params describe a well-formed eltwise
// primitive. The visible checks are:
//   - both p and o report KernelType::ELTWISE,
//   - there is at least one input and at least one operation,
//   - every operation has exactly GetNumberOfInputs(mode) inputs,
//   - every INPUT_BUFFER operand refers to an index inside params.inputs.
// NOTE(review): the bodies of these checks are not visible in this view; presumably
// each failing check makes the function return false - confirm.
77 bool EltwiseKernelBase::Validate(const Params& p, const optional_params& o) const {
78 if (p.GetType() != KernelType::ELTWISE || o.GetType() != KernelType::ELTWISE) {
// The cast below relies on the kernel-type check above.
82 const eltwise_params& params = static_cast<const eltwise_params&>(p);
84 if (params.inputs.size() == 0) {
88 auto& operations = params.operations;
90 if (operations.size() == 0) {
94 for (size_t op_num = 0; op_num < operations.size(); op_num++) {
95 const auto& ew = operations[op_num];
// Operand count must match what the mode expects.
97 if (ew.inputs.size() != GetNumberOfInputs(ew.mode)) {
101 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
102 const auto& input = ew.inputs[input_idx];
// Only INPUT_BUFFER operands are range-checked here; other operand modes are not.
103 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index >= params.inputs.size()) {
// Builds the JIT constants (preprocessor-style definitions) for an eltwise kernel.
// Starting from the base-params constants it adds:
//   - feature flags (ELTWISE_LAYOUT_BASED, QUANTIZATION_TERM, ELTWISE_BROADCAST, ...),
//   - OUTPUT_IDX_ORDER and one INPUT<i>_IDX_ORDER per input,
//   - INPUTS_DECLS / VLOAD_DECLS declaration strings,
//   - one INPUT_<op>_<operand> definition per operand and one OPERATION<op> per operation,
//   - DO_ELTWISE, which chains the operations and assigns the last tmp to res.
// params    - eltwise parameters to generate constants for.
// useVload8 - GetJitConstants() passes false. NOTE(review): the branches testing this
//             flag are not visible in this view; presumably it selects the 8-wide
//             vload forms (VLOAD_DECLS, in<i> operand names, vector tmp type) - confirm.
// Returns the populated JitConstants (the return statement itself is not visible here).
112 JitConstants EltwiseKernelBase::GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const {
113 JitConstants jit = MakeBaseParamsJitConstants(params);
// Helper: returns a comma-separated list of index expressions written in terms of
// d1..d4 for a tensor of layout l. When params.stride is non-empty the last two
// entries are multiplied by stride.y and stride.x respectively.
115 auto GetIdxOrderForLayout = [&](DataLayout l, bool layoutBased, uSize stride) -> std::string {
116 // TODO: Generalize this method
117 std::vector<std::string> bfyx_idx_order = {};
119 bfyx_idx_order = { "d4", "d3", "d2", "d1" };
// yxfb and fyxb get their own orderings; the remaining assignment is presumably the
// fallback for other layouts (the else line is not visible here).
121 if (l == DataLayout::yxfb) {
122 bfyx_idx_order = { "d1", "d2", "d3", "d4" };
123 } else if (l == DataLayout::fyxb) {
124 bfyx_idx_order = { "d1", "d4", "d3", "d2" };
126 bfyx_idx_order = { "d4", "d3", "d2", "d1" };
130 if (!params.stride.empty()) {
131 bfyx_idx_order[2] = "(" + bfyx_idx_order[2] + "*" + std::to_string(stride.y) + ")";
132 bfyx_idx_order[3] = "(" + bfyx_idx_order[3] + "*" + std::to_string(stride.x) + ")";
135 return bfyx_idx_order[0] + "," +
136 bfyx_idx_order[1] + "," +
137 bfyx_idx_order[2] + "," +
// Boolean options from params exposed to the kernel source.
142 MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
143 MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
144 MakeJitConstant("ELTWISE_BROADCAST", params.broadcast),
// Int8 quantization: with output calibration, O_QF is taken from the first
// calibration factor; the output_quantization_factor form below is presumably the
// else-branch (the else line is not visible here).
147 if (params.int8_quantization) {
148 if (params.output_calibration) {
149 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
150 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
153 jit.AddConstants({MakeJitConstant("O_QF", params.output_quantization_factor)});
157 std::string inputs_decls, vload_decls;
158 auto& updateInputs = params.updateInputIds;
// Name of the macro holding the output index order, and the stride used when no
// per-input stride is supplied.
160 std::string out_idx_order = "OUTPUT_IDX_ORDER";
161 uSize out_stride = {1, 1, 1};
// NOTE(review): OUTPUT_IDX_ORDER is defined in several places below; the conditions
// separating them (likely useVload8 and the output channel count) are not all visible.
163 jit.AddConstant(MakeJitConstant(out_idx_order, "d1"));
// Flat d1 indexing is used when all tensors share dims without pitches and none of
// layoutBased / int8_quantization / broadcast is set.
165 if (CheckInputsOutputNoPitchSameDims(params) &&
166 !(params.layoutBased || params.int8_quantization || params.broadcast)) {
167 jit.AddConstant(MakeJitConstant(out_idx_order, "d1"));
// Otherwise the order depends on the output's channel count; 5-channel outputs
// use all of d5..d1.
169 size_t out_c = DataTensor::ChannelsCount(params.output.GetLayout());
171 jit.AddConstant(MakeJitConstant(out_idx_order, GetIdxOrderForLayout(params.output.GetLayout(),
172 params.layoutBased || params.broadcast,
174 } else if (out_c == 5) {
175 jit.AddConstant(MakeJitConstant(out_idx_order, "d5,d4,d3,d2,d1"));
// Strided variant: publish the (unit) output stride.
181 if (!params.stride.empty()) {
182 jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_X", out_stride.x));
183 jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_Y", out_stride.y));
184 jit.AddConstant(MakeJitConstant("OUTPUT_STRIDE_Z", out_stride.z));
// Per-input pass: kernel argument declaration, stride constants, vload declaration
// and INPUT<i>_IDX_ORDER.
187 for (size_t i = 0; i < params.inputs.size(); i++) {
188 // const should be added only to inputs which will not be updated
189 std::string const_str = "const";
190 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++) {
191 if (updateInputs[update_input_idx].inputId == i) {
198 const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
// NOTE(review): params.stride[i] is read for every input whenever stride is
// non-empty; this assumes stride has at least inputs.size() entries - confirm.
199 if (!params.stride.empty()) {
200 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
201 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
202 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Z", params.stride[i].z));
204 std::string idx_order = "INPUT" + std::to_string(i) + "_IDX_ORDER";
// 8-wide declaration for this input: a single-element input is splatted from
// element 0, otherwise the value comes from vload8 at global_id.
206 vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
207 if (params.inputs[i].PhysicalSize() == 1) // Scalar case
208 vload_decls += " = (" + toCLType(params.inputs[i].GetDType()) + "8)(input" + std::to_string(i) + "[0]";
210 vload_decls += " = vload8(global_id, input" + std::to_string(i);
212 jit.AddConstant(MakeJitConstant(idx_order, "d1"));
// Same flat-vs-layout decision as for the output, applied per input.
214 if (CheckInputsOutputNoPitchSameDims(params) &&
215 !(params.layoutBased || params.int8_quantization || params.broadcast)) {
216 jit.AddConstant(MakeJitConstant(idx_order, "d1"));
218 size_t in_c = DataTensor::ChannelsCount(params.inputs[i].GetLayout());
219 size_t out_c = DataTensor::ChannelsCount(params.output.GetLayout());
// Without explicit strides the input uses the unit output stride.
220 auto in_stride = params.stride.empty() ? out_stride : params.stride[i];
221 if (out_c <= 4 && in_c <= 4) {
222 jit.AddConstant(MakeJitConstant(idx_order, GetIdxOrderForLayout(params.inputs[i].GetLayout(),
223 params.layoutBased || params.broadcast,
225 } else if (out_c == 5) {
227 // Skip Z coord for 4d tensors
228 jit.AddConstant(MakeJitConstant(idx_order, "d5,d4,d2,d1"));
229 } else if (in_c == 5) {
230 jit.AddConstant(MakeJitConstant(idx_order, "d5,d4,d3,d2,d1"));
232 } else if (out_c <= 4 && in_c == 5) {
233 // quite strange case, but can happen due to reorders fusing
234 // it means that z coord is equal to 1, so z offset will be always equal to 0
235 jit.AddConstant(MakeJitConstant(idx_order, "d4,d3,0,d2,d1"));
243 jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
244 jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
247 jit.AddConstant(MakeJitConstant("VLOAD_DECLS", vload_decls));
// Accumulates the body of DO_ELTWISE: one OPERATION<n>; entry per operation.
249 std::string do_eltwise;
251 auto& operations = params.operations;
252 auto& coefficients = params.coefficients;
254 for (size_t op_num = 0; op_num < operations.size(); op_num++) {
255 const std::string op_num_str = std::to_string(op_num);
256 const auto& ew = operations[op_num];
// Define INPUT_<op>_<operand> for every operand, according to where it is read from.
258 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
259 const auto& input = ew.inputs[input_idx];
260 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
261 std::string idx_order = "INPUT" + std::to_string(input.index) + "_IDX_ORDER";
263 switch (input.mode) {
// Literal scalar operand.
264 case EltwiseInputMode::SCALAR:
265 jit.AddConstant(MakeJitConstant(name, input.scalar));
// Buffer operand: either the preloaded in<i> value or an indexed read of input<i>
// (the condition choosing between the two is not visible here).
267 case EltwiseInputMode::INPUT_BUFFER:
269 jit.AddConstant(MakeJitConstant(name, "in" + std::to_string(input.index)));
271 jit.AddConstant(MakeJitConstant(name,
272 "input" + std::to_string(input.index) +
273 "[GET_INDEX(INPUT, " + std::to_string(input.index) +
274 "," + idx_order +")]"));
// Operand read back from the output buffer at the output index order.
276 case EltwiseInputMode::OUTPUT_BUFFER:
277 jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT,,"+ out_idx_order +")]"));
// Operand read from input<i> at a position given by an earlier tmp result.
279 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
280 jit.AddConstant(MakeJitConstant(
282 "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
// Operand is the result of an earlier operation.
284 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
285 jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
// Pick the type of tmp<op> and the cast applied to operands: an 8-wide UNIT_TYPE
// vector, int for int8 quantization, or plain UNIT_TYPE.
292 std::string input0_str, input1_str, cast_type, output_cast, op;
295 cast_type = "(MAKE_VECTOR_TYPE(UNIT_TYPE, 8))";
296 op = "const MAKE_VECTOR_TYPE(UNIT_TYPE, 8) tmp" + op_num_str + " = ";
297 } else if (params.int8_quantization) {
299 op = "const int tmp" + op_num_str + " = ";
301 cast_type = "(UNIT_TYPE)";
302 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
// INT8 output without quantization: results are cast to char and operands to the
// dtype of params.inputs[op_num].
// NOTE(review): params.inputs is indexed by the operation number here, and Validate()
// does not bound operations.size() by inputs.size() - possible out-of-range access
// when there are more operations than inputs; confirm.
305 if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) {
306 output_cast = "(char)";
307 cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")";
310 input0_str = cast_type + "INPUT_" + op_num_str + "_0";
311 input1_str = cast_type + "INPUT_" + op_num_str + "_1";
// ADD only: INPUT_BUFFER operands that have a matching entry in coefficients get a
// "(c)*" prefix. NOTE(review): a guard between reading c and building the prefix may
// exist on a line not visible here.
313 if (ew.mode == EltwiseMode::ADD) {
314 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
315 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++) {
316 const auto& input = ew.inputs[input_idx];
317 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size()) {
318 const float c = coefficients[input.index];
320 coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
324 input0_str = coeff_strings[0] + input0_str;
325 input1_str = coeff_strings[1] + input1_str;
// Append the expression for this mode to op.
329 case EltwiseMode::ADD:
330 op += input0_str + " + " + input1_str;
332 case EltwiseMode::SUB:
333 op += input0_str + " - " + input1_str;
335 case EltwiseMode::MUL:
336 op += input0_str + " * " + input1_str;
338 case EltwiseMode::DIV:
339 op += input0_str + " / " + input1_str;
// MODULU / MIN / MAX: integer operands use mod/min/max (or % for MODULU), any
// non-integer operand switches to the f-prefixed function with convert_float on the
// integer side.
// NOTE(review): the dtypes are taken from params.inputs[0] and params.inputs[1], not
// from this operation's own operands, and Validate() only guarantees one input -
// confirm this is intended and that inputs[1] always exists here.
341 case EltwiseMode::MODULU:
342 case EltwiseMode::MIN:
343 case EltwiseMode::MAX: {
344 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
345 auto input_0_type = params.inputs[0].GetDType();
346 auto input_1_type = params.inputs[1].GetDType();
349 if (input_0_type == kernel_selector::Datatype::INT8 ||
350 input_0_type == kernel_selector::Datatype::INT32 ||
351 input_0_type == kernel_selector::Datatype::INT64) {
352 // input_0 == int && input_1 == int
353 if (input_1_type == kernel_selector::Datatype::INT8 ||
354 input_1_type == kernel_selector::Datatype::INT32 ||
355 input_1_type == kernel_selector::Datatype::INT64) {
356 if (ew.mode == EltwiseMode::MODULU)
357 op += input0_str + " % " + input1_str;
359 op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
361 // input_0 == int && input_1 != int
362 op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
364 } else if (input_1_type == kernel_selector::Datatype::INT8 ||
365 input_1_type == kernel_selector::Datatype::INT32 ||
366 input_1_type == kernel_selector::Datatype::INT64) {
367 // input_0 != int && input_1 == int
368 op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
370 // input_0 != int && input_1 != int
371 op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
374 case EltwiseMode::POW:
375 op += cast_type + "pow(" + input0_str + ", " + input1_str + ")";
377 case EltwiseMode::SQRT:
378 op += cast_type + "sqrt(" + input0_str + ")";
// NOTE(review): the cast here binds to the literal 1, not to the whole quotient.
380 case EltwiseMode::RSQRT:
381 op += cast_type + "1/sqrt(" + input0_str + ")";
// (a - b) multiplied by (a - b); the joining text sits on lines not visible here.
383 case EltwiseMode::SQUARED_DIFF:
384 op += cast_type + "((" + input0_str + " - " + input1_str +
387 input0_str + " - " + input1_str + "))";
// Comparison and logical modes wrap the result in output_cast.
389 case EltwiseMode::EQ:
390 op += output_cast + "(" + input0_str + " == " + input1_str + ")";
392 case EltwiseMode::NE:
393 op += output_cast + "(" + input0_str + " != " + input1_str + ")";
395 case EltwiseMode::LT:
396 op += output_cast + "(" + input0_str + " < " + input1_str + ")";
398 case EltwiseMode::LE:
399 op += output_cast + "(" + input0_str + " <= " + input1_str + ")";
401 case EltwiseMode::GT:
402 op += output_cast + "(" + input0_str + " > " + input1_str + ")";
404 case EltwiseMode::GE:
405 op += output_cast + "(" + input0_str + " >= " + input1_str + ")";
407 case EltwiseMode::LOGIC_AND:
408 op += output_cast + "(" + input0_str + " && " + input1_str + ")";
410 case EltwiseMode::LOGIC_OR:
411 op += output_cast + "(" + input0_str + " || " + input1_str + ")";
413 case EltwiseMode::LOGIC_XOR:
414 op += output_cast + "(!" + input0_str + " != !" + input1_str + ")";
// NOTE(review): the generated text is a - a / b * b with no floor call. For
// floating-point operands the quotient is not floored, and for integers the division
// truncates toward zero - verify this matches the intended floor-mod semantics.
416 case EltwiseMode::FLOOR_MOD:
417 op += output_cast + "(" + input0_str + " - " + input0_str + " / " + input1_str + " * " + input1_str + ")";
419 case EltwiseMode::ASSIGN:
// Publish OPERATION<n> and append its invocation to the DO_ELTWISE body.
426 std::string opname = "OPERATION" + op_num_str;
427 jit.AddConstant(MakeJitConstant(opname, op));
428 do_eltwise += "\\\n\t" + opname + ";";
// Write-back of tmp results into the inputs listed in updateInputIds.
431 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
432 do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) + "[GET_INDEX(INPUT, " +
433 std::to_string(updateInputs[update_input_idx].inputId) + ", " +
434 "INPUT"+std::to_string(updateInputs[update_input_idx].inputId) + "_IDX_ORDER)] = tmp" +
435 std::to_string(updateInputs[update_input_idx].tmpId) + ";";
// The kernel result is the tmp of the last operation.
437 do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
439 jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
// Layout-based / quantized / broadcast kernels also get the work-group helper
// constants derived from the output tensor.
441 if (params.layoutBased || params.int8_quantization || params.broadcast) {
442 jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
445 if (!params.stride.empty()) {
446 jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
// Default JIT-constant generation for eltwise kernels: forwards to
// GetJitConstantsCommon with useVload8 set to false.
452 JitConstants EltwiseKernelBase::GetJitConstants(const eltwise_params& params) const {
453 return GetJitConstantsCommon(params, false);
// Computes the default dispatch (global / local work sizes) for an eltwise kernel.
// Three cases are distinguished by the same predicates GetJitConstantsCommon uses:
//   1. layoutBased / int8_quantization / broadcast: sizes come from
//      GetTensorFriendlyWorkGroups(output),
//   2. all tensors share dims without pitches: gws0 is the output's logical size,
//   3. otherwise: sizes are derived from the output dims.
// NOTE(review): the declaration of kd, several assignments and the return statement
// are on lines not visible in this view.
456 EltwiseKernelBase::DispatchData EltwiseKernelBase::SetDefault(const eltwise_params& params) const {
459 if (params.layoutBased || params.int8_quantization || params.broadcast) {
460 auto global = GetTensorFriendlyWorkGroups(params.output);
464 } else if (CheckInputsOutputNoPitchSameDims(params)) {
465 kd.gws0 = params.output.LogicalSize();
// General case: collect the output dims into gws.
469 const auto& out = params.output;
471 std::vector<size_t> gws;
472 for (const auto& o : out.GetDims()) {
// bfzyx outputs are handled separately; presumably this selects n_dims - confirm.
477 if (out.GetLayout() == DataLayout::bfzyx)
// Extends gws up to n_dims entries (the value appended is not visible here).
482 for (size_t i = gws.size(); i < n_dims; i++) {
// Dims are folded pairwise into the three global sizes; the first pair of assignments
// is presumably the 5-dim branch and the last one the 4-dim branch.
488 kd.gws1 = gws[1] * gws[2]; // y*z
489 kd.gws2 = gws[3] * gws[4];
492 kd.gws2 = gws[2] * gws[3];
496 auto local = GetOptimalLocalWorkGroupSizes({kd.gws0, kd.gws1, kd.gws2});
// Special handling for bfyx_f16 outputs whose feature count is a multiple of 16
// (the rest of the condition and its body are not visible here).
498 if (params.output.GetLayout() == DataLayout::bfyx_f16 && params.output.Feature().v % 16 == 0 &&
// Shared implementation for producing the KernelsData of an eltwise kernel:
// validates the parameters, clones them into a KernelData, generates the JIT source
// and dispatch sizes, and fills in the first kernel entry.
// NOTE(review): the early-exit body after Validate() and the final return are on
// lines not visible in this view; presumably an empty result is returned when
// validation fails - confirm.
512 KernelsData EltwiseKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const {
513 if (!Validate(params, options)) {
// Work on the copy of the params owned by kd.
517 KernelData kd = KernelData::Default<eltwise_params>(params);
518 eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
// Entry point name, JIT constants and the resulting JIT string for this kernel.
520 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
521 auto cldnn_jit = GetJitConstants(newParams);
522 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
524 DispatchData runInfo = SetDefault(newParams);
526 auto& kernel = kd.kernels[0];
// Copy the dispatch sizes into the kernel's work-group description.
528 kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
529 kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
531 kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
// The argument descriptor takes the input count and the quantization / calibration
// flags (two further arguments are on lines not visible here).
532 kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(),
535 newParams.int8_quantization,
536 newParams.output_calibration);
// Priority hint taken from the DONT_USE_IF_HAVE_SOMETHING_ELSE constant.
538 kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
542 } // namespace kernel_selector