2 // Copyright (c) 2016-2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
20 #include "kernel_selector_common.h"
26 namespace kernel_selector {
/// List of (name, value) string pairs; this is what JitConstant::GetDefinitions() returns.
using JitDefinitions = std::vector<std::pair<std::string, std::string>>;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// GetTypeName
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Maps a C++ scalar type T to the type-name string used when emitting kernel code
/// (e.g. uint8_t -> "uchar", int64_t -> "long").
/// Only the explicit specializations below are supported.
/// @throws std::runtime_error for any T without a specialization (checked at run time, not compile time).
template <typename T>
inline std::string GetTypeName() { throw std::runtime_error("Implement me"); }

template <>
inline std::string GetTypeName<int8_t>() { return "char"; }

template <>
inline std::string GetTypeName<uint8_t>() { return "uchar"; }

template <>
inline std::string GetTypeName<int16_t>() { return "short"; }

template <>
inline std::string GetTypeName<uint16_t>() { return "ushort"; }

template <>
inline std::string GetTypeName<int32_t>() { return "int"; }

template <>
inline std::string GetTypeName<uint32_t>() { return "uint"; }

template <>
inline std::string GetTypeName<int64_t>() { return "long"; }

template <>
inline std::string GetTypeName<uint64_t>() { return "ulong"; }

template <>
inline std::string GetTypeName<float>() { return "float"; }

template <>
inline std::string GetTypeName<double>() { return "double"; }
58 std::string toCLType(WeightsType wType);
59 std::string toCLType(Datatype dType);
60 std::string getMeanOpString(MeanOp op);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// ToCodeString functions
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// TODO improve to_code_string specializations

/// Renders a value as the text of a JIT definition value.
/// The generic overload delegates to std::to_string, so T must be a type std::to_string accepts.
/// Non-template overloads below take precedence for strings, C strings and bool.
template <typename T>
std::string toCodeString(T val) { return std::to_string(val); }

/// std::string values are emitted verbatim.
inline std::string toCodeString(const std::string& val) { return val; }

/// C strings are emitted verbatim. A null pointer yields an empty string: constructing a
/// std::string from a null `const char*` is undefined behaviour.
inline std::string toCodeString(const char* val) { return val ? std::string(val) : std::string(); }

/// Booleans are emitted as "1" / "0" rather than std::to_string's output.
inline std::string toCodeString(bool val) { return val ? "1" : "0"; }

/// Floating-point overloads are declared here and defined out of line.
std::string toCodeString(float val);
std::string toCodeString(double val);
75 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
77 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a C compound-literal style initializer string for `vec`:
//   "(" + vectorType + " []){ " followed by one "<value>," entry per element.
// Each element is first mapped through fetchFunc and then rendered with toCodeString().
// If vec.size() < maxDim, the remaining (maxDim - vec.size()) slots are filled with
// padFillingVal, which is streamed directly (not passed through toCodeString()).
// NOTE(review): this copy of the block has gaps in its embedded line numbers (80-81, 87-90):
// the opening brace, the declaration of `ss`, and the tail that presumably closes the
// literal and returns the accumulated text are not visible here — verify against the full source.
78 template <typename VecT, typename ValT, typename Func>
79 inline std::string toVectorString(const VecT& vec, const std::string& vectorType, size_t maxDim, ValT padFillingVal, Func fetchFunc)
82 ss << "(" << vectorType << " []){ ";
// One entry per element of vec.
83 for (size_t i = 0; i < vec.size(); i++)
84 ss << toCodeString(fetchFunc(vec[i])) << ",";
// Pad up to maxDim entries; runs zero times when vec.size() >= maxDim.
85 for (size_t i = vec.size(); i < maxDim; i++)
86 ss << padFillingVal << ",";
91 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
93 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
97 const std::string _name;
98 JitConstant(const std::string& name):_name(name){}
101 virtual JitDefinitions GetDefinitions() const = 0;
102 virtual ~JitConstant() {}
105 class simple_jit_constant : public JitConstant
107 const std::string _value;
110 simple_jit_constant(const std::string& name, const std::string& value)
111 : JitConstant(name), _value(value) {}
113 JitDefinitions GetDefinitions() const override
115 return JitDefinitions{ {_name, _value} };
120 std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, T value)
122 return std::static_pointer_cast<JitConstant>(std::make_shared<simple_jit_constant>(name, toCodeString(value)));
125 std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const struct Tensor::DataTensor& value);
126 std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const struct Tensor::WeightsTensor& value);
128 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
129 // VectorDataJitConstant
130 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// JitConstant wrapping a std::vector<T>; the constructor stores its own copy of the data.
// NOTE(review): this copy of the block has gaps in its embedded line numbers (133, 135-136,
// 138, 140, 144-148): braces, the access label, and the end of GetDefinitions (closing the
// initializer and returning `result`) are not visible here — verify against the full source.
131 template <typename T>
132 class VectorDataJitConstant : public JitConstant
134 const std::vector<T> _data;
137 VectorDataJitConstant(const std::string& name, const std::vector<T>& data) : JitConstant(name), _data(data) {}
// Visible entries: <name>_SIZE (element count) and <name> (elements rendered by toVectorString
// with element type GetTypeName<T>()).
139 JitDefinitions GetDefinitions() const override
141 JitDefinitions result{
142 { _name + "_SIZE", toCodeString(_data.size()) },
// maxDim is _data.size(), so the padding value 1 is never emitted; the lambda passes each
// element through unchanged.
143 { _name, toVectorString(_data, GetTypeName<T>(), _data.size(), 1, [](const T& v) {return v; } ) },
149 template <typename T>
150 inline std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const std::vector<T>& value)
152 return std::static_pointer_cast<JitConstant>(std::make_shared<VectorDataJitConstant<T>>(name, value));
155 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
157 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// JitConstant wrapping a Size<T>.
// NOTE(review): this copy of the block has gaps in its embedded line numbers (160-163, 165,
// 167, 171-175): the `_size` member declaration, braces, the access label and the end of
// GetDefinitions are not visible here — verify against the full source.
158 template <typename T>
159 class SizeJitConstant : public JitConstant
// Stores `size` in `_size` (member declared on a line not visible in this copy).
164 SizeJitConstant(const std::string& name, const Size<T>& size) : JitConstant(name), _size(size) {}
// Visible entries: <name>_SIZE_X and <name>_SIZE_Y from _size.x / _size.y.
166 JitDefinitions GetDefinitions() const override
168 JitDefinitions definitions{
169 { _name + "_SIZE_X", toCodeString(_size.x) },
170 { _name + "_SIZE_Y", toCodeString(_size.y) },
176 template <typename T>
177 inline std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const Size<T>& value)
179 return std::static_pointer_cast<JitConstant>(std::make_shared<SizeJitConstant<T>>(name, value));
182 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
184 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// JitConstant wrapping a DimTensor<T>; the constructor stores its own copy in `_dims`.
// NOTE(review): this copy of the block has gaps in its embedded line numbers (187, 189-190,
// 192, 194, 200-204): braces, the access label and the end of GetDefinitions are not
// visible here — verify against the full source.
185 template <typename T>
186 class DimVectorJitConstant : public JitConstant
188 const DimTensor<T> _dims;
191 DimVectorJitConstant(const std::string& name, const DimTensor<T>& size) : JitConstant(name), _dims(size) {}
// Visible entries map _dims.b/.f/.y/.x to <name>_BATCH_NUM/_FEATURE_NUM/_SIZE_Y/_SIZE_X.
193 JitDefinitions GetDefinitions() const override
195 JitDefinitions definitions{
196 { _name + "_BATCH_NUM", toCodeString(_dims.b) },
197 { _name + "_FEATURE_NUM", toCodeString(_dims.f) },
198 { _name + "_SIZE_Y", toCodeString(_dims.y) },
199 { _name + "_SIZE_X", toCodeString(_dims.x) },
205 template <typename T>
206 std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const DimTensor<T>& value)
208 return std::make_shared<DimVectorJitConstant<T>>(name, value);
211 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
213 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
216 std::vector<std::shared_ptr<JitConstant>> _constants;
218 JitConstants(std::initializer_list<std::shared_ptr<JitConstant>> constants) :_constants(constants) {}
220 inline void AddConstant(std::shared_ptr<JitConstant> constant)
222 _constants.push_back(constant);
225 inline void AddConstants(const std::vector<std::shared_ptr<JitConstant>>& constants)
227 for (const auto& c : constants)
229 _constants.push_back(c);
233 inline void Merge(const JitConstants& jit)
235 AddConstants(jit._constants);
238 JitDefinitions GetDefinitions() const;
241 JitConstants MakeActivationJitConstants(const base_activation_params& params, const std::string& suffix="");
242 JitConstants MakeBaseParamsJitConstants(const base_params& params);
243 JitConstants MakeLoopUnrollParamsJitConstants(uint32_t loopCount);
244 JitConstants MakeUnitTypeJitConstants(Datatype dataType);