2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "eltwise_kernel_b_fs_yx_fsv4.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Reports the parameter combinations this kernel accepts: INT8/UINT8 input
// and output data, b_fs_yx_fsv4 input and output layouts, tensor offsets and
// pitches, int8 quantization, output calibration and eltwise stride.
// NOTE(review): this listing omits some lines of the function (e.g. where `k`
// is declared and returned); these comments describe only the visible lines.
22 ParamsKey EltwiseKernel_b_fs_yx_fsv4::GetSupportedKey() const
25 k.EnableInputDataType(Datatype::INT8);
26 k.EnableInputDataType(Datatype::UINT8);
27 k.EnableOutputDataType(Datatype::INT8);
28 k.EnableOutputDataType(Datatype::UINT8);
29 k.EnableInputLayout(DataLayout::b_fs_yx_fsv4);
30 k.EnableOutputLayout(DataLayout::b_fs_yx_fsv4);
31 k.EnableTensorOffset();
32 k.EnableTensorPitches();
34 k.EnableInt8Quantization();
35 k.EnableOutputCalibration();
// NOTE(review): stride support is advertised here, but Validate() tests
// `newParams.stride.empty()` (the body of that branch is not visible in this
// listing). Confirm that strided inputs are really accepted by this kernel.
36 k.EnableEltwiseStride();
// Computes dispatch sizes for the linearized kernel. The whole output
// (X * Y * Batch * Feature values) is treated as one dimension and split into
// work items of 4*4 = 16 values each.
// NOTE(review): the declaration of `kd`, any other DispatchData fields and the
// return are on lines not shown in this listing.
40 EltwiseKernelBase::DispatchData EltwiseKernel_b_fs_yx_fsv4::SetDefault(const eltwise_params& params) const
44 // Because of very specific requirements for data, we may linearize the data,
45 // i.e. use only one dimension, e.g. 'X'.
48 // we process 4*4 (4 int8 bytes per on block_read4 reading) features per workitem
// NOTE(review): this is integer division. If the element count is not a
// multiple of 16, the trailing values get no work item. The visible part of
// Validate() does not check this, so confirm the guarantee exists elsewhere.
49 kd.gws0 = params.output.X().v * params.output.Y().v *
50 params.output.Batch().v * params.output.Feature().v / (4*4);
// Priority hint consumed by the kernel selector. The exact meaning of
// FORCE_PRIORITY_1 is defined outside this file.
58 kd.effiency = FORCE_PRIORITY_1;
// Decides whether this kernel can run `params`. The kernel linearizes the
// data, so beyond the parent's checks the visible code inspects three things:
// the stride list, equality of all input tensors, and zero padding on input 0.
// NOTE(review): the branch bodies and return statements are not shown in this
// listing.
62 bool EltwiseKernel_b_fs_yx_fsv4::Validate(const Params& params, const optional_params& options) const
64 // Requirements to use 'eltwise_b_fs_yx_fsv4' kernel are below:
66 // 2. All dimensions for all inputs are the same
68 // So, it can be linearized
70 if (!Parent::Validate(params, options)) {
// Obtains eltwise_params through KernelData::Default. In the visible lines
// these params are only read.
74 KernelData kd = KernelData::Default<eltwise_params>(params);
75 eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
// A non-empty stride presumably disqualifies the kernel; the branch body is
// hidden.
78 if (!newParams.stride.empty()) {
// NOTE(review): `inputs.size() - 1` is unsigned and wraps around for an empty
// input list. This presumably relies on Parent::Validate guaranteeing at least
// one input.
82 for (size_t i = 0; i < newParams.inputs.size() - 1; i++)
84 // 2. All dimensions for all inputs are the same
85 if (!(newParams.inputs[i] == newParams.inputs[i + 1])) {
// Padding is checked on input 0 only. This is presumably sufficient because
// the loop above requires all inputs to compare equal, which depends on the
// fields the tensor operator== compares (confirm). Output padding is not
// checked in the visible lines.
90 const auto& in = newParams.inputs[0];
91 for (size_t i = 0; i < in.Dimentions(); i++)
94 if ((in.GetDims()[i].pad.before != 0) ||
95 (in.GetDims()[i].pad.after != 0)) {
// Builds the JIT constants for the eltwise_b_fs_yx_fsv4 OpenCL kernel:
//  - base params plus the flags ELTW_UNSIGNED, ELTWISE_LAYOUT_BASED,
//    QUANTIZATION_TERM, CALIBRATION_TERM and O_QF;
//  - INPUTS_DECLS, the kernel's input-pointer parameter list;
//  - one OPERATION<n> macro per eltwise operation, chained into DO_ELTWISE.
//    DO_ELTWISE also writes back updated inputs and assigns the last result
//    to `res`;
//  - work-group and INPUT_STRIDED constants.
// NOTE(review): this listing omits several lines of the function (braces,
// `break`s, the `switch` headers, some `else`s and the return). The comments
// describe only what is visible.
103 JitConstants EltwiseKernel_b_fs_yx_fsv4::GetJitConstants(const eltwise_params& params) const
105 JitConstants jit = MakeBaseParamsJitConstants(params);
// Only the data type of input 0 decides ELTW_UNSIGNED.
107 if (params.inputs[0].GetDType() == Datatype::UINT8) {
108 // Special handler for unsigned types
110 MakeJitConstant("ELTW_UNSIGNED", 1)
116 MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
117 MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
// Quantized output. With calibration enabled, O_QF is the first calibration
// factor. Otherwise it is presumably the scalar output_quantization_factor
// (the `else` is on a hidden line).
120 if (params.int8_quantization)
122 if (params.output_calibration)
124 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
// NOTE(review): only element [0] is used, so one factor is applied to the
// whole output. This assumes output_calibration_factors is non-empty.
125 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
129 jit.AddConstants({ MakeJitConstant("O_QF", params.output_quantization_factor) });
// INPUTS_DECLS holds "<const> __global <type>* input<i>, " for every input.
// Every entry, including the last, ends with ", ", so the kernel source
// presumably places another parameter after this macro (confirm).
132 std::string inputs_decls;
133 auto& updateInputs = params.updateInputIds;
135 for (size_t i = 0; i < params.inputs.size(); i++)
137 //const should be added only to inputs which will not be updated
138 std::string const_str = "const";
// The statement run on a match (presumably clearing const_str) is on a line
// not shown in this listing.
139 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
141 if (updateInputs[update_input_idx].inputId == i)
148 inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
151 jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
152 jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
154 std::string do_eltwise;
156 auto& operations = params.operations;
157 auto& coefficients = params.coefficients;
159 for (size_t op_num = 0; op_num < operations.size(); op_num++)
161 const std::string op_num_str = std::to_string(op_num);
162 const auto& ew = operations[op_num];
// Defines INPUT_<op>_<k>, the source expression for operand k of operation
// op, chosen by the operand's input mode.
164 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
166 const auto& input = ew.inputs[input_idx];
167 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
170 case EltwiseInputMode::SCALAR:
171 jit.AddConstant(MakeJitConstant(name, input.scalar));
173 case EltwiseInputMode::INPUT_BUFFER:
174 jit.AddConstant(MakeJitConstant(name, "GET_INPUT(input" + std::to_string(input.index) + ", INPUT" + std::to_string(input.index) + ")"));
176 case EltwiseInputMode::OUTPUT_BUFFER:
177 jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
179 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
180 jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
182 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
183 jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
// Every operand is cast to int16, and each result is stored in
// `const int16 tmp<op>`. int16 is the OpenCL vector of 16 ints, presumably
// matching the 16 values per work item set in SetDefault.
189 std::string input0_str, input1_str, cast_type, op;
191 cast_type = "(int16)";
192 op = "const int16 tmp" + op_num_str + " = ";
194 input0_str = cast_type + "INPUT_" + op_num_str + "_0";
195 input1_str = cast_type + "INPUT_" + op_num_str + "_1";
// ADD: in the visible lines, each INPUT_BUFFER operand whose index has an
// entry in `coefficients` gets a "<cast>(<coeff>)*" prefix.
// NOTE(review): casting a float to an integer vector in OpenCL truncates
// toward zero, so a fractional coefficient such as 0.5 becomes 0.
// std::to_string(float) also prints only six decimals. The two lines
// indexing coeff_strings[0] and [1] assume ADD has two operands.
197 if (ew.mode == EltwiseMode::ADD)
199 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
200 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
202 const auto& input = ew.inputs[input_idx];
203 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
205 const float c = coefficients[input.index];
207 coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
211 input0_str = coeff_strings[0] + input0_str;
212 input1_str = coeff_strings[1] + input1_str;
// Builds the right-hand side of tmp<op> according to the operation mode.
218 case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
219 case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
220 case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
221 case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
222 case EltwiseMode::MODULU:
223 case EltwiseMode::MIN:
224 case EltwiseMode::MAX:
226 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max"));
// NOTE(review): the integer/float choice below uses the types of the
// primitive's inputs 0 and 1, not those of this operation's operands (which
// may be other inputs, scalars or tmp values). Indexing inputs[1] also
// requires at least two inputs. GetSupportedKey enables only INT8/UINT8
// inputs, so the floating-point branches appear unreachable for this kernel.
227 auto input_0_type = params.inputs[0].GetDType();
228 auto input_1_type = params.inputs[1].GetDType();
231 if (input_0_type == kernel_selector::Datatype::INT8 ||
232 input_0_type == kernel_selector::Datatype::UINT8)
234 // input_0 == int && input_1 == int
235 if (input_1_type == kernel_selector::Datatype::INT8 ||
236 input_1_type == kernel_selector::Datatype::UINT8)
238 if (ew.mode == EltwiseMode::MODULU)
239 op += input0_str + " % " + input1_str;
241 op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
243 // input_0 == int && input_1 != int
246 op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
249 // input_0 != int && input_1 == int
250 else if (input_1_type == kernel_selector::Datatype::INT8 ||
251 input_1_type == kernel_selector::Datatype::UINT8)
253 op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
255 // input_0 != int && input_1 != int
258 op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
// NOTE(review): OpenCL pow and sqrt are floating-point built-ins, but the
// operands here are int16 vectors. In RSQRT the cast binds only to the
// literal 1, giving "(int16)1/sqrt(x)", not to the quotient. Confirm these
// modes compile and produce the intended values.
261 case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
262 case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
263 case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
264 case EltwiseMode::ASSIGN: op += input0_str; break;
// Emits OPERATION<op> and appends a call to it to DO_ELTWISE. Each appended
// statement is preceded by a backslash, a newline and a tab, which acts as a
// macro line continuation.
269 std::string opname = "OPERATION" + op_num_str;
270 jit.AddConstant(MakeJitConstant(opname, op));
271 do_eltwise += "\\\n\t" + opname + ";";
// Writes tmp results back into the inputs listed in updateInputIds.
274 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
275 do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
276 "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
277 ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
// The final result is the last operation's tmp value.
// NOTE(review): `operations.size() - 1` wraps around if there are no
// operations. Presumably at least one operation is guaranteed upstream.
279 do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
281 jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
283 if (params.layoutBased || params.int8_quantization)
285 jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
288 if (!params.stride.empty())
290 jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
// Produces this kernel's KernelsData by delegating to the shared eltwise
// routine GetCommonKernelsData. That routine presumably invokes the
// Validate/SetDefault/GetJitConstants overrides of this class; confirm in the
// base class.
297 KernelsData EltwiseKernel_b_fs_yx_fsv4::GetKernelsData(const Params& params, const optional_params& options) const
299 return GetCommonKernelsData(params, options);