2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
20 #include "cldnn_defs.h"
21 #include "compounds.h"
31 /// @addtogroup cpp_api C++ API
34 /// @addtogroup cpp_topology Network Topology
37 /// @brief Globally unique primitive type id.
38 using primitive_type_id = cldnn_primitive_type_id;
39 /// @brief C API compatible unique @p id of a primitive within a topology.
40 using primitive_id_ref = cldnn_primitive_id;
41 /// @brief Unique @p id of a primitive within a topology.
42 using primitive_id = std::string;
44 /// @brief Dynamic cast to specified primitive description type.
46 typename PType::dto* as_dto(CLDNN_PRIMITIVE_DESC(primitive)* dto)
48 if (dto->type != PType::type_id()) throw std::invalid_argument("type");
49 return reinterpret_cast<typename PType::dto*>(dto);
52 /// @brief Dynamic cast to specified primitive description type.
54 const typename PType::dto* as_dto(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
56 if (dto->type != PType::type_id()) throw std::invalid_argument("type");
57 return reinterpret_cast<const typename PType::dto*>(dto);
60 /// @brief Base class of network primitive description.
63 /// @brief Initialize fields common for all primitives.
64 struct fixed_size_vector_ref
67 std::vector<primitive_id>& vref;
70 fixed_size_vector_ref(std::vector<primitive_id>& ref) : vref(ref)
73 auto size() const -> decltype(vref.size()) { return vref.size(); }
74 auto begin() const -> decltype(vref.begin()) { return vref.begin(); }
75 auto end() const -> decltype(vref.end()) { return vref.end(); }
76 auto cbegin() const -> decltype(vref.cbegin()) { return vref.cbegin(); }
77 auto cned() const -> decltype(vref.cend()) { return vref.cend(); }
79 primitive_id& operator[](size_t idx) { return vref[idx]; }
80 primitive_id const& operator[](size_t idx) const { return vref[idx]; }
82 primitive_id& at(size_t idx) { return vref.at(idx); }
83 primitive_id const& at(size_t idx) const { return vref.at(idx); }
85 primitive_id* data() { return vref.data(); }
86 const primitive_id* data() const { return vref.data(); }
88 const std::vector<primitive_id>& ref() const { return vref; }
92 const primitive_type_id& type,
93 const primitive_id& id,
94 const std::vector<primitive_id>& input,
95 const padding& output_padding = padding()
97 :type(type), id(id), input(_input.cpp_ids), output_padding(output_padding), _input(input)
100 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{primitive}
101 primitive(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
102 :type(dto->type), id(dto->id), input(_input.cpp_ids), output_padding(dto->output_padding), _input(dto->input)
105 virtual ~primitive() = default;
107 /// @brief Requested output padding.
108 /// @brief Requested output padding.
109 /// @brief Returns pointer to a C API primitive descriptor casted to @CLDNN_PRIMITIVE_DESC{primitive}.
110 virtual const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const = 0;
112 /// @brief Returns references to all primitive ids on which this primitive depends - inputs, weights, biases, etc.
113 std::vector<std::reference_wrapper<primitive_id>> dependencies()
115 std::vector<std::reference_wrapper<primitive_id>> result;
116 auto&& deps = get_dependencies();
118 result.reserve(_input.size() + deps.size());
119 for (auto& pid : _input.cpp_ids)
120 result.push_back(std::ref(pid));
121 for (auto& pid : deps)
122 result.push_back(std::ref(const_cast<primitive_id&>(pid.get())));
127 /// @brief Returns copy of all primitive ids on which this primitive depends - inputs, weights, biases, etc.
128 std::vector<primitive_id> dependencies() const
130 auto result = input.ref();
131 auto deps = get_dependencies();
132 result.insert(result.end(), deps.begin(), deps.end());
136 /// @brief Implicit conversion to primiitive id.
137 operator primitive_id() const { return id; }
139 /// @brief Primitive's type id.
140 const primitive_type_id type;
142 /// @brief Primitive's id.
143 const primitive_id id;
145 /// @brief List of ids of input primitives.
146 fixed_size_vector_ref input;
148 /// @brief Requested output padding.
149 padding output_padding;
152 struct primitive_id_arr
154 primitive_id_arr(std::vector<primitive_id> const& vec) : cpp_ids(vec)
157 primitive_id_arr(std::vector<primitive_id>&& vec) : cpp_ids(std::move(vec))
160 //create from C API id array
161 primitive_id_arr(cldnn_primitive_id_arr c_id_arr)
163 cpp_ids.resize(c_id_arr.size);
164 for (size_t i = 0; i < c_id_arr.size; ++i)
165 cpp_ids[i] = c_id_arr.data[i];
168 std::vector<primitive_id> cpp_ids;
169 mutable std::vector<cldnn_primitive_id> c_ids;
171 auto ref() const -> decltype(cldnn_primitive_id_arr{c_ids.data(), c_ids.size()})
173 c_ids.resize(cpp_ids.size());
174 for (size_t i = 0; i < cpp_ids.size(); ++i)
175 c_ids[i] = cpp_ids[i].c_str();
177 return cldnn_primitive_id_arr{ c_ids.data(), c_ids.size() };
180 size_t size() const { return cpp_ids.size(); }
183 primitive_id_arr _input;
185 virtual std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const { return{}; }
188 /// @brief base class for all primitives implementations.
189 template<class PType, class DTO>
190 class primitive_base : public primitive
193 /// @brief Returns pointer to a C API primitive descriptor casted to @CLDNN_PRIMITIVE_DESC{primitive}.
194 const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const override
196 //update common dto fields
197 _dto.id = id.c_str();
199 _dto.input = _input.ref();
200 _dto.output_padding = output_padding;
202 //call abstract method to update primitive-specific fields
204 return reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(&_dto);
208 explicit primitive_base(
209 const primitive_id& id,
210 const std::vector<primitive_id>& input,
211 const padding& output_padding = padding())
212 : primitive(PType::type_id(), id, input, output_padding)
215 primitive_base(const DTO* dto)
216 : primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
218 if (dto->type != PType::type_id())
219 throw std::invalid_argument("DTO type mismatch");
225 virtual void update_dto(DTO& dto) const = 0;
/// Defines a static @c type_id() member that obtains the type id of @p PType from the C API
/// function @c cldnn_<PType>_type_id, wrapped in check_status with the message "<PType> type id failed".
#define CLDNN_DEFINE_TYPE_ID(PType)                                              \
    static primitive_type_id type_id()                                           \
    {                                                                            \
        return check_status<primitive_type_id>(#PType " type id failed",         \
            [](status_t* status) { return cldnn_##PType##_type_id(status); });   \
    }

/// Declares the nested @c dto alias (the C API descriptor of @p PType) and its static @c type_id().
#define CLDNN_DECLARE_PRIMITIVE(PType)                                           \
    using dto = CLDNN_PRIMITIVE_DESC(PType);                                     \
    CLDNN_DEFINE_TYPE_ID(PType)