2 // Copyright (c) 2017 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
20 #include "api/CPP/primitive.hpp"
21 #include "internal_primitive.h"
23 #include "meta_utils.h"
class graph_initializations;

// Node type bound to a concrete primitive type; defined at the end of this header and
// may be explicitly specialised for individual primitive types.
template <class PType>
struct typed_program_node;

template <class PType>
struct internal_primitive_type_base;
/*
    Base class for all primitives: wraps the API class and extends it for use in the graph context.

    Besides the primitive description provided by the user, this class provides access to the direct
    predecessors and successors of a node, and takes care of changes to the primitive that would affect
    other nodes of the graph (the most common case is probably calculating the output layout).

    At graph level, all connections between nodes are stored directly inside program_nodes - in contrast
    to the API level, where primitives store only the ids of related ones.
*/
54 friend struct program_impl; // to be removed when possible
55 friend class compile_graph; // to be removed when possible
56 friend class graph_initializations; // to be removed when possible
57 friend class prepare_primitive_fusing; // to be removed when possible
58 friend class prepare_conv_eltw_fusing; // to be removed when possible
59 friend class prepare_conv_eltw_read_write_opt; // to be removed when possible
60 friend class propagate_constants; // to be removed when possible
61 friend class post_optimize_weights; // to be removed when possible - requires an access to selected_impl
63 template <class PType>
64 friend struct typed_program_node;
66 program_node(std::shared_ptr<primitive> prim, program_impl& prog);
68 program_node(program_node const&) = delete;
70 virtual ~program_node() = default;
73 virtual const primitive_id& id() const { return desc->id; }
74 virtual primitive_type_id type() const { return desc->type; }
76 template <class PType>
79 static_assert(meta::is_primitive<PType>::value, "Type argument for program_node::is_type should be a non-const, non-volatile type derived from primitive");
80 return type() == PType::type_id();
83 program_impl& get_program() { return myprog; }
84 program_impl const& get_program() const { return myprog; }
86 std::shared_ptr<primitive_impl> get_selected_impl() const { return selected_impl; }
88 std::vector<program_node*> const& get_dependencies() const { return dependencies; }
89 program_node& get_dependency(size_t idx) const { return *dependencies.at(idx); }
91 //replaces idx-th dependency of 'this' with 'new_dep', calls program::remove_if_dangling(old_dep)
92 void replace_dependency(size_t idx, program_node& new_dep);
93 //searches for 'old_dep' in dependencies list of 'this' and replaces it with 'new_dep', calls program::remove_if_dangling(old_dep)
94 void replace_dependency(program_node const& old_dep, program_node& new_dep);
96 std::vector<primitive_id> get_dependencies_ids() const;
98 void remove_dependency(size_t idx);
99 void remove_dependency(program_node& node);
101 std::set<primitive_id> get_memory_dependencies() const;
102 void add_memory_dependency(primitive_id);
103 void add_memory_dependency(std::vector<primitive_id>);
105 template<class PType>
106 bool have_user_with_type() const
108 for (auto const& usr : users)
110 if (usr->is_type<PType>()) return true;
115 bool is_detached(bool whole_branch = false);
117 std::list<program_node*> const& get_users() { return users; }
118 // for const method, add const to stored successors/predecessors
119 std::list<const program_node*> const& get_users() const { return reinterpret_cast<const std::list<const program_node*>&>(users); }
121 std::unique_ptr<json_composite> desc_to_json() const;
122 //do not modify primitive directly to keep synchronisation with graph
123 std::shared_ptr<const primitive> get_primitive() const { return desc; }
124 //primitive modification functions
125 void set_output_padding(padding const& padd)
127 //changing output padding shouldn't cause any changes to other primitives
129 output_layout.data_padding = padd;
132 void merge_output_padding(padding const& padd)
134 set_output_padding(padding::max(padd, output_layout.data_padding));
137 //only calculated output layout (for external usage), does not modify/use cached output layout nor invalidate users
138 layout calc_output_layout() const;
140 //uses cached output layout if valid, if not calls 'calc_output_layout' and stores its result + invalidate all users if layout has changed and @p invalidate_users_if_changed is set to true
141 layout get_output_layout(bool invalidate_users_if_changed = true);
142 //returns cached output layout if valid, otherwise throws an exception
143 layout get_output_layout() const;
144 //returns result of get_output_layout without padding
145 layout get_non_padded_output_layout(bool invalidate_users_if_changed = true);
147 //sets cached output layout to an arbitrary value, invalidates users if new layout differs from previous one and @p invalidate_users_if_changed is set to true
148 //returns whether output layout has changed
149 bool set_output_layout(layout new_layout, bool invalidate_users_if_changed = true);
151 //forces recalculation of cached output layout, invalidates users if new layout is different than previous one and @p invalidate_users_if_changed is set to true
152 //returns whether output layout has changed
153 bool recalc_output_layout(bool invalidate_users_if_changed = true);
155 bool is_padded() { return static_cast<bool>(get_output_layout().data_padding); }
156 bool is_padded() const { return static_cast<bool>(get_output_layout().data_padding); }
158 bool has_padded_dependency();
159 bool has_padded_dependency() const;
161 bool is_input() const { return dependencies.empty(); }
162 bool is_endpoint() const { return users.empty(); }
163 void set_output(bool out) { output = out; }
164 bool is_output() const { return output; }
166 bool is_valid_output_layout() const { return valid_output_layout; }
168 uint8_t mark(uint8_t val = 1) { uint8_t ret = user_mark; user_mark = val; return ret; }
169 void unmark() { user_mark = 0; }
170 bool is_marked() const { return user_mark != 0; }
171 bool is_marked(uint8_t val) const { return user_mark == val; }
172 uint8_t get_user_mark() const { return user_mark; }
174 void set_fused_activation(cldnn_activation_func activation_func, cldnn_activation_additional_params additional_params)
176 fused_activation.activation_func = activation_func;
177 fused_activation.additional_params = additional_params;
180 cldnn_activation_func get_fused_activation_func() const
182 return fused_activation.activation_func;
185 cldnn_activation_additional_params get_fused_activation_params() const
187 return fused_activation.additional_params;
190 // check/set if the node can be optimized out (removed from the network)
191 bool can_be_optimized() const { return optimized; }
192 void can_be_optimized(bool opt) { optimized = opt; }
194 // check/set if the node's buffer can be shared during the memory pool optimization
195 bool can_share_buffer() const { return share_buffer; }
196 void can_share_buffer(bool share) { share_buffer = share; }
198 // check/set if the node support padding in x,y,b and f
199 bool support_padding() const { return _support_padding; }
200 void support_padding(bool support) { _support_padding = support; }
202 primitive_id get_org_primitive_id() const { return org_id; }
204 bool is_constant() const { return constant; }
206 // returns true if this node is within main data flow of the network (i.e. it does not describe helper data like convolution's weights etc.)
207 bool is_in_data_flow() const { return data_flow; }
209 //conversion from generic to specific
210 template <class To, class..., class = typename std::enable_if<!std::is_same<To, primitive>::value>::type>
211 typed_program_node<To>& as()
213 if (type() != To::type_id())
214 throw std::invalid_argument("program_node: mismatching primitive's type");
216 return reinterpret_cast<typed_program_node<To>&>(*this);
219 template <class To, class..., class = typename std::enable_if<!std::is_same<To, primitive>::value>::type>
220 typed_program_node<To> const& as() const
222 if (type() != To::type_id())
223 throw std::invalid_argument("program_node: mismatching primitive's type");
225 return reinterpret_cast<typed_program_node<To> const&>(*this);
229 operator typed_program_node<To>& ()
235 operator typed_program_node<To> const& () const
240 void set_reused_memory_color(uint32_t color) const
242 has_reused_memory = true;
243 reused_memory_color = color;
246 bool is_reusing_memory() { return has_reused_memory; };
247 uint32_t get_reused_memory_color() { return reused_memory_color; ; }
250 std::shared_ptr<primitive> desc;
251 program_impl& myprog;
253 std::shared_ptr<primitive_impl> selected_impl;
255 bool valid_output_layout = false;
256 layout output_layout = layout(data_types::f32, format::bfyx, tensor());
258 std::vector<program_node*> dependencies;
259 std::list<program_node*> users;
261 // list of primitives that can reuse same memory buffers due to execution order conflicts
262 std::set<primitive_id> memory_dependencies;
264 bool constant = false;
265 bool data_flow = false;
268 uint8_t user_mark = 0;
269 bool optimized = false;
270 bool share_buffer = true;
271 bool _support_padding = false;
273 mutable bool has_reused_memory = false;
274 mutable uint32_t reused_memory_color = 0;
276 const primitive_id org_id;
278 struct fused_activation_params
280 cldnn_activation_func activation_func = activation_none;
281 cldnn_activation_additional_params additional_params = { 0.0f, 0.0f };
284 fused_activation_params fused_activation;
286 void invalidate_users() const;
291 template <class PType>
292 struct api_typed_program_node_base : public program_node
294 static_assert(meta::is_api_primitive<PType>::value, "PType should name a non-const, non-volatile type derived from cldnn::primitive but not from cldnn::internal_primitive");
295 friend class cldnn::graph_initializations;
296 friend struct cldnn::program_impl;
297 friend class cldnn::reorder_inputs;
299 using program_node::program_node;
301 std::shared_ptr<const PType> get_primitive() const { return std::static_pointer_cast<const PType>(program_node::get_primitive()); }
304 std::shared_ptr<PType> typed_desc() const { return std::static_pointer_cast<PType>(desc); }
307 struct internal_program_node_base : public program_node
309 friend struct cldnn::program_impl;
311 internal_program_node_base(program_impl& prog);
313 const primitive_id& id() const override { return internal_id; }
315 void set_implementation(std::unique_ptr<primitive_impl>&& impl);
318 primitive_id internal_id;
320 static primitive_id get_next_internal_id();
323 template <class PType>
324 struct internal_typed_program_node_base : public internal_program_node_base
326 static_assert(meta::is_internal_primitive<PType>::value, "PType should name a non-const, non-volatile type derived from cldnn::internal_primitive");
329 using internal_program_node_base::internal_program_node_base;
331 primitive_type_id type() const override { return PType::type_id(); }
333 template <class... Guard>
335 void get_primitive(Guard&&...)
337 static_assert(meta::always_false<meta::pack<Guard...>>::value, "Trying to get primitive from internal node");
342 template <class... Guard>
344 void typed_desc(Guard&&...)
346 static_assert(meta::always_false<meta::pack<Guard...>>::value, "Trying to get primitive from internal node");
/*
    Template class used to indicate that the usage context requires 'program_node' to wrap a primitive
    of type 'PType'. A successful conversion from 'program_node' to 'typed_program_node<PType>' means
    that this restriction in fact holds, so functions/methods/etc. may safely use the underlying primitive.

    This class shadows the 'get_primitive' method from the base class, which now returns a pointer to the
    more specific type.
*/
359 template <class PType>
360 using typed_program_node_base = typename std::conditional<meta::is_api_primitive<PType>::value, details::api_typed_program_node_base<PType>, details::internal_typed_program_node_base<PType>>::type;
/*
    Actual template class used in contexts which require 'program_node' to wrap a
    primitive of type 'PType'. This class exists so that it can be explicitly specialized.
    In most cases such specializations would add accessors that make PType-specific fields easier to reach.

    It's not required to specialize this class for new primitive types.
*/
// Primary template: derives from the base selected by typed_program_node_base<PType>.
369 template <class PType>
370 struct typed_program_node : public typed_program_node_base<PType>
// Inherit the base class constructors.
372 using typed_program_node_base<PType>::typed_program_node_base;
// First dependency of this node. get_dependency() uses std::vector::at, so this throws
// std::out_of_range when the node has no dependencies.
374 program_node& input() const { return program_node::get_dependency(0); }