2 // Copyright (c) 2018-2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "eltwise_kernel_fs_bs_yx_bsv4_fsv32.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Describes the parameter combinations this kernel accepts, by enabling
// capabilities on the key object `k`.
22 ParamsKey EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetSupportedKey() const
// Data types: INT8 on both the input and the output side.
25 k.EnableInputDataType(Datatype::INT8);
26 k.EnableOutputDataType(Datatype::INT8);
// Layout: fs_bs_yx_bsv4_fsv32 on both the input and the output side.
27 k.EnableInputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
28 k.EnableOutputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
// Tensor offsets and pitches are enabled.
29 k.EnableTensorOffset();
30 k.EnableTensorPitches();
// Quantization-related features consumed by GetJitConstants
// (QUANTIZATION_TERM / CALIBRATION_TERM / O_QF) and per-input strides
// (INPUT<i>_STRIDE_X/Y, INPUT_STRIDED).
32 k.EnableInt8Quantization();
33 k.EnableOutputCalibration();
34 k.EnableEltwiseStride();
// Fills the dispatch data `kd` for the given eltwise parameters.
// The global work sizes are derived from the output tensor only.
38 EltwiseKernelBase::DispatchData EltwiseKernel_fs_bs_yx_bsv4_fsv32::SetDefault(const eltwise_params& params) const
// gws0 / gws1 take the output X and Y extents (`.v`) directly.
42 kd.gws0 = params.output.X().v;
43 kd.gws1 = params.output.Y().v;
44 // we process 4 batches and 4 features per workitem
// NOTE(review): both divisions by 4 round down (assuming `.v` is an integer
// type), whereas GetJitConstants rounds the batch count up with
// (Batch + 3) / 4 when computing IN_F_BLOCK_PITCH. Confirm that batch and
// feature counts are always multiples of 4 for this layout.
45 kd.gws2 = (params.output.Batch().v / 4) * (params.output.Feature().v / 4);
// NOTE(review): presumably the kernel-selection priority for this
// implementation; confirm against the FORCE_PRIORITY_* definitions.
50 kd.effiency = FORCE_PRIORITY_3;
// Builds the JIT constants for this eltwise kernel from `params`:
//   * input pitches/offset (IN_X_PITCH, IN_Y_PITCH, IN_B_BLOCK_PITCH,
//     IN_F_BLOCK_PITCH, IN_OFFSET), computed from params.inputs[0];
//   * ELTWISE_LAYOUT_BASED, QUANTIZATION_TERM, CALIBRATION_TERM and O_QF;
//   * INPUTS_DECLS (kernel argument declarations) and per-input stride
//     constants;
//   * one INPUT_<op>_<idx> macro per operation input and one OPERATION<op>
//     macro per operation, chained together in DO_ELTWISE.
54 JitConstants EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetJitConstants(const eltwise_params& params) const
56 JitConstants jit = MakeBaseParamsJitConstants(params);
// Pitches are derived from the first input only, using padded logical sizes.
// NOTE(review): 32 * 4 presumably corresponds to the fsv32 / bsv4 blocks in
// the layout name - confirm.
58 const size_t in_x_pitch = 32 * 4;
59 const size_t in_y_pitch = 32 * 4 * params.inputs[0].X().LogicalDimPadded();
60 const size_t in_b_block_pitch = in_y_pitch * params.inputs[0].Y().LogicalDimPadded();
// The batch count is rounded up to whole groups of 4.
61 const size_t in_f_block_pitch = in_b_block_pitch * ((params.inputs[0].Batch().v + 3) / 4);
// Offset of the first element: skip the leading X and Y padding.
62 const size_t in_offset = in_x_pitch * params.inputs[0].X().pad.before + in_y_pitch * params.inputs[0].Y().pad.before;
64 jit.AddConstant(MakeJitConstant("IN_X_PITCH", in_x_pitch));
65 jit.AddConstant(MakeJitConstant("IN_Y_PITCH", in_y_pitch));
66 jit.AddConstant(MakeJitConstant("IN_B_BLOCK_PITCH", in_b_block_pitch));
67 jit.AddConstant(MakeJitConstant("IN_F_BLOCK_PITCH", in_f_block_pitch));
68 jit.AddConstant(MakeJitConstant("IN_OFFSET", in_offset));
// Flags forwarded verbatim from the parameters.
72 MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
73 MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
// Quantized path: O_QF is taken either from the first output calibration
// factor (emitted together with CALIBRATION_TERM) or from the scalar
// output quantization factor.
76 if (params.int8_quantization)
78 if (params.output_calibration)
80 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
81 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
85 jit.AddConstants({ MakeJitConstant("O_QF", params.output_quantization_factor) });
// Build the kernel argument list: one "__global <type>* input<i>, " entry
// per input, accumulated into inputs_decls.
88 std::string inputs_decls;
89 auto& updateInputs = params.updateInputIds;
91 for (size_t i = 0; i < params.inputs.size(); i++)
93 //const should be added only to inputs which will not be updated
94 std::string const_str = "const";
// Look for input i among the inputs that get written back.
95 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
97 if (updateInputs[update_input_idx].inputId == i)
104 inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
// Per-input X/Y strides are emitted only when strides were supplied.
// NOTE(review): params.stride is indexed by the input index after only a
// non-empty check; this assumes it has one entry per input - confirm.
106 if (!params.stride.empty())
108 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
109 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
113 jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
114 jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
// do_eltwise accumulates the statement sequence later emitted as DO_ELTWISE.
116 std::string do_eltwise;
118 auto& operations = params.operations;
119 auto& coefficients = params.coefficients;
// One pass per operation; operation N produces the kernel-side value tmpN.
121 for (size_t op_num = 0; op_num < operations.size(); op_num++)
123 const std::string op_num_str = std::to_string(op_num);
124 const auto& ew = operations[op_num];
// Define INPUT_<op>_<idx> for every input of this operation; the expansion
// depends on the input mode.
126 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
128 const auto& input = ew.inputs[input_idx];
129 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
// Scalar: the literal value.
132 case EltwiseInputMode::SCALAR:
133 jit.AddConstant(MakeJitConstant(name, input.scalar));
// Input buffer: read through the GET_INPUT macro.
135 case EltwiseInputMode::INPUT_BUFFER:
136 jit.AddConstant(MakeJitConstant(name, "GET_INPUT(input" + std::to_string(input.index) + ", INPUT" + std::to_string(input.index) + ")"));
// Output buffer: read the current output element.
138 case EltwiseInputMode::OUTPUT_BUFFER:
139 jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
// Unordered access: index the input buffer with an earlier result tmp<k>.
141 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
142 jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
// Intermediate result: reuse an earlier result tmp<k> directly.
144 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
145 jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
// `op` starts as the declaration "const <type> tmp<op> = " and receives the
// expression below; <type> is int16 on the quantized path and UNIT_TYPE
// for the other assignment pair.
151 std::string input0_str, input1_str, cast_type, op;
153 if (params.int8_quantization)
155 cast_type = "(int16)";
156 op = "const int16 tmp" + op_num_str + " = ";
160 cast_type = "(UNIT_TYPE)";
161 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
// Both operands are cast to the working type.
164 input0_str = cast_type + "INPUT_" + op_num_str + "_0";
165 input1_str = cast_type + "INPUT_" + op_num_str + "_1";
// ADD only: prefix an operand with "<cast>(<coefficient>)*" when it is an
// input buffer whose index has an entry in `coefficients`.
167 if (ew.mode == EltwiseMode::ADD)
169 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
170 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
172 const auto& input = ew.inputs[input_idx];
173 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
175 const float c = coefficients[input.index];
177 coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
// NOTE(review): elements 0 and 1 are read regardless of ew.inputs.size();
// this assumes every ADD operation has at least two inputs - confirm.
181 input0_str = coeff_strings[0] + input0_str;
182 input1_str = coeff_strings[1] + input1_str;
// Append the expression text for the operation mode.
188 case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
189 case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
190 case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
191 case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
// MODULU / MIN / MAX share one code path; the emitted function depends on
// whether the operands are integer (INT8 / INT32 / INT64) or not.
192 case EltwiseMode::MODULU:
193 case EltwiseMode::MIN:
194 case EltwiseMode::MAX:
196 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
// NOTE(review): the data types come from params.inputs[0] and
// params.inputs[1], not from the inputs this operation references, and
// inputs[1] is read without a size check - confirm this is intended.
197 auto input_0_type = params.inputs[0].GetDType();
198 auto input_1_type = params.inputs[1].GetDType();
201 if (input_0_type == kernel_selector::Datatype::INT8 ||
202 input_0_type == kernel_selector::Datatype::INT32 ||
203 input_0_type == kernel_selector::Datatype::INT64)
205 // input_0 == int && input_1 == int
206 if (input_1_type == kernel_selector::Datatype::INT8 ||
207 input_1_type == kernel_selector::Datatype::INT32 ||
208 input_1_type == kernel_selector::Datatype::INT64)
// Integer modulo uses the % operator; integer min/max use the plain
// "min"/"max" name.
210 if (ew.mode == EltwiseMode::MODULU)
211 op += input0_str + " % " + input1_str;
213 op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
215 // input_0 == int && input_1 != int
// Mixed operands: the "f"-prefixed function (fmod / fmin / fmax) with
// convert_float applied to the integer side.
218 op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
221 // input_0 != int && input_1 == int
222 else if (input_1_type == kernel_selector::Datatype::INT8 ||
223 input_1_type == kernel_selector::Datatype::INT32 ||
224 input_1_type == kernel_selector::Datatype::INT64)
226 op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
228 // input_0 != int && input_1 != int
231 op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
234 case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
235 case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
// NOTE(review): in the emitted text the cast is followed directly by
// "1/sqrt(", so by C precedence it applies to the literal 1 rather than to
// the quotient - confirm this is intended.
236 case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
237 case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
238 " * (" + input0_str + " - " + input1_str + "))"; break;
// Comparison and logical modes: the result of the operator is cast to the
// working type.
239 case EltwiseMode::EQ: op += cast_type + "(" + input0_str + " == " + input1_str + ")"; break;
240 case EltwiseMode::NE: op += cast_type + "(" + input0_str + " != " + input1_str + ")"; break;
241 case EltwiseMode::LT: op += cast_type + "(" + input0_str + " < " + input1_str + ")"; break;
242 case EltwiseMode::LE: op += cast_type + "(" + input0_str + " <= " + input1_str + ")"; break;
243 case EltwiseMode::GT: op += cast_type + "(" + input0_str + " > " + input1_str + ")"; break;
244 case EltwiseMode::GE: op += cast_type + "(" + input0_str + " >= " + input1_str + ")"; break;
245 case EltwiseMode::LOGIC_AND: op += cast_type + "(" + input0_str + " && " + input1_str + ")"; break;
246 case EltwiseMode::LOGIC_OR: op += cast_type + "(" + input0_str + " || " + input1_str + ")"; break;
247 case EltwiseMode::LOGIC_XOR: op += cast_type + "(!" + input0_str + " != !" + input1_str + ")"; break;
248 case EltwiseMode::ASSIGN: op += input0_str; break;
// Register the statement as OPERATION<op> and append a reference to it to
// do_eltwise, prefixed with a backslash-newline-tab sequence.
253 std::string opname = "OPERATION" + op_num_str;
254 jit.AddConstant(MakeJitConstant(opname, op));
255 do_eltwise += "\\\n\t" + opname + ";";
// Write-back: store tmp<tmpId> into input<inputId> for every updated input.
258 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
259 do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
260 "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
261 ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
// The result of the last operation becomes `res`.
// NOTE(review): operations.size() - 1 wraps around if `operations` is empty;
// presumably callers guarantee at least one operation - confirm.
263 do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
265 jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
// Layout-based or quantized kernels also merge in the work-group constants
// computed from the first input.
267 if (params.layoutBased || params.int8_quantization)
269 jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
// Flag the strided variant whenever strides were supplied.
272 if (!params.stride.empty())
274 jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
// Produces the kernel data for this kernel by delegating to
// GetCommonKernelsData with the same params and options.
281 KernelsData EltwiseKernel_fs_bs_yx_bsv4_fsv32::GetKernelsData(const Params& params, const optional_params& options) const
283 return GetCommonKernelsData(params, options);