2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "common_kernel_base.h"
20 #if defined __INTEL_COMPILER
21 #pragma warning disable: 177
24 namespace kernel_selector
// Accumulates the generated preprocessor text ("#define ..." lines written by
// decoration_macro() / value_macro()).
30 std::ostringstream oss;
// Every macro name registered so far; register_macro() asserts that a name is
// never added twice.
32 std::vector<std::string> defined_macroses;
34 CodeBuilder& register_macro(const std::string& name)
36 assert(std::count(defined_macroses.begin(), defined_macroses.end(), name) == 0);
37 defined_macroses.push_back(name);
// Presumably stores `c` as the `code` text that the builder's str() method emits
// after the macro block — NOTE(review): confirm against the body.
42 CodeBuilder& set_code(const std::string& c)
// Presumably appends `line` verbatim to the generated header text (CreateJit uses it
// for its banner comment lines) — NOTE(review): confirm against the body.
49 CodeBuilder& add_line(const std::string& line) {
54 CodeBuilder& decoration_macro(const std::string& name, const std::string& prefix, const std::string& postfix, const std::string& name_prefix = std::string())
56 oss << "#define " << name << "(name) " << prefix << " " + name_prefix + "_##" + "name" << (postfix.empty() ? "" : "##_") << postfix << std::endl;
57 return register_macro(name);
61 CodeBuilder& value_macro(const std::string& name, const std::string& value)
63 oss << "#define " << name << " " << value << std::endl;
64 return register_macro(name.substr(0, name.find('(')));
// Fragment of CodeBuilder's text-export method (CreateJit obtains the JIT via code.str()):
// a local stream receives the `code` text followed by a newline.
// NOTE(review): presumably the accumulated `oss` macro text is written to `os` first and
// os.str() is returned — confirm.
69 std::ostringstream os;
71 os << code << std::endl;
77 std::string common_kernel_base::GetEntryPoint(const std::string& templateName, const std::string& layerID, const optional_params& options) const
79 std::string kernelID = layerID;
81 if (kernelID.empty() || !options.meaningfulKernelsNames)
83 kernelID = templateName;
86 std::replace(kernelID.begin(), kernelID.end(), '.', '_');
87 std::replace(kernelID.begin(), kernelID.end(), '/', '_');
89 kernelID += "_" + std::to_string(UniqeID());
94 std::string common_kernel_base::CreateJit(const std::string& template_name, const JitConstants& constants, const std::string& kernel_id) const
96 class CodeBuilder code;
97 code.add_line("\n//====================================================")
98 .add_line("// Kernel template: " + template_name + " ")
99 .add_line("// Kernel name: " + kernel_id)
100 .value_macro("KERNEL(name)", "__kernel void " + kernel_id)
101 .decoration_macro("FUNC", "", kernel_id)
102 .decoration_macro("FUNC_CALL", "", kernel_id);
104 for (auto& definition : constants.GetDefinitions())
106 code.value_macro(definition.first, definition.second);
109 std::string jit = code.str();
// Builds the ordered argument-descriptor list for a kernel.
// Order: INPUT 0..num_of_input-1, OUTPUT 0, then WEIGHTS and BIAS, then
// WEIGHTS_QUANTIZATION_FACTORS (only when use_quantization && use_weights) and
// OUTPUT_CALIBRATION_FACTORS (only when use_output_calibration).
// NOTE(review): the WEIGHTS / BIAS pushes are presumably guarded by use_weights /
// use_bias respectively — confirm.
114 Arguments common_kernel_base::GetArgsDesc(uint32_t num_of_input, bool use_weights, bool use_bias, bool use_quantization, bool use_output_calibration) const
// One INPUT descriptor per input, indexed from 0.
118 for (uint32_t i = 0; i < num_of_input; i++)
120 args.push_back({ ArgumentDescriptor::Types::INPUT, i });
123 args.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
127 args.push_back({ ArgumentDescriptor::Types::WEIGHTS, 0 });
132 args.push_back({ ArgumentDescriptor::Types::BIAS, 0 });
// Weight quantization factors are only emitted when weights are used as well.
135 if (use_quantization && use_weights)
137 args.push_back({ ArgumentDescriptor::Types::WEIGHTS_QUANTIZATION_FACTORS, 0 });
140 if (use_output_calibration)
142 args.push_back({ ArgumentDescriptor::Types::OUTPUT_CALIBRATION_FACTORS, 0 });
// Creates the KernelString for kernel `name`:
//   - str: the first entry returned by db.get(name);
//   - jit and entry_point: stored as given;
//   - options: exe_mode + " -cl-mad-enable", plus -DMMAD_SUPPORTED=1 / -DIMAD_SUPPORTED=1
//     when engine_info reports the corresponding support;
//   - batch_compilation: set to true.
148 std::shared_ptr<KernelString> common_kernel_base::GetKernelString(const std::string& name, const std::string& jit, const std::string& entry_point, const EngineInfo& engine_info, const std::string& exe_mode) const
150 std::shared_ptr<KernelString> kernel_string = std::make_shared<KernelString>();
152 auto codes = db.get(name);
// NOTE(review): codes[0] requires db.get(name) to return at least one entry;
// confirm this section is guarded against an empty result.
156 kernel_string->str = codes[0];
157 kernel_string->jit = jit;
158 kernel_string->options = exe_mode + " -cl-mad-enable";
159 if (engine_info.bIMMADSupport)
160 kernel_string->options += " -DMMAD_SUPPORTED=1";
161 if (engine_info.bIMADSupport)
162 kernel_string->options += " -DIMAD_SUPPORTED=1";
163 kernel_string->entry_point = entry_point;
164 kernel_string->batch_compilation = true;
167 return kernel_string;
170 static void Check_RunInfoData(const std::string &kernelName, const kernel_selector::CommonDispatchData &runInfo)
172 if (runInfo.lws0 * runInfo.lws1 * runInfo.lws2 > 256)
174 std::cout << "ERROR: dispatch data for kernel: " << kernelName << " LWS cannot be greater than 256!\n" << std::endl;
176 if (runInfo.gws0 == 0 || runInfo.gws1 == 0 || runInfo.gws2 == 0 || runInfo.lws0 == 0 || runInfo.lws1 == 0 || runInfo.lws2 == 0)
178 std::cout << "ERROR: dispatch data for kernel: " << kernelName << " dispatch data cannot contain zeros!" << std::endl;
180 if (runInfo.gws0 % runInfo.lws0 != 0)
182 std::cout << "ERROR: dispatch data for kernel: " << kernelName << " is incorrect: GWS0: " << runInfo.gws0 << " LWS0: " << runInfo.lws0 << std::endl;
184 if (runInfo.gws1 % runInfo.lws1 != 0)
186 std::cout << "ERROR: dispatch data for kernel: " << kernelName << " is incorrect: GWS1: " << runInfo.gws1 << " LWS1: " << runInfo.lws1 << std::endl;
188 if (runInfo.gws2 % runInfo.lws2 != 0)
190 std::cout << "ERROR: dispatch data for kernel: " << kernelName << " is incorrect: GWS2: " << runInfo.gws2 << " LWS2: " << runInfo.lws2 << std::endl;
194 void common_kernel_base::FillCLKernelData(clKernelData& kernel, const CommonDispatchData& runInfo, const EngineInfo& engine_info,
195 const std::string& kernelMapName, const std::string& jit, const std::string& entryPoint, const std::string& exeMode, bool weights, bool bias, int number_of_inputs, bool quantization, bool calibration) const
197 Check_RunInfoData(kernelMapName, runInfo);
198 kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
199 kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
200 kernel.kernelString = GetKernelString(kernelMapName, jit, entryPoint, engine_info, exeMode);
201 kernel.arguments = GetArgsDesc(number_of_inputs, weights, bias, quantization, calibration);