Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / program_node.h
1 /*
2 // Copyright (c) 2017 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 #pragma once
17
18 #include <set>
19
20 #include "api/CPP/primitive.hpp"
21 #include "internal_primitive.h"
22
23 #include "meta_utils.h"
24
25 namespace cldnn
26 {
27
28 struct program_impl;
29 class reorder_inputs;
30 class graph_initializations;
31
32 template <class T>
33 struct typed_program_node;
34
35 template <class PType>
36 struct internal_primitive_type_base;
37
38 class json_composite;
39 class xml_composite;
40
41 /*
42     Base class for all primitives which wraps API class and extends it to be used
43     in graph context.
44
45     Besides primitive description provided by user, this class includes functionality to
46     ask for direct predecessors and succesors as well as takes care of changes to primitive
47     which would affect other graph's nodes (the most commont case is probably calculating output layout).
48
49     At graph level, all connections between nodes are directly stored inside program_nodes - in oposite
50     to API level where all primitives store only ids of related ones.
51 */
52 struct program_node
53 {
54     friend struct program_impl;                     // to be removed when possible
55     friend class compile_graph;                     // to be removed when possible
56     friend class graph_initializations;             // to be removed when possible
57     friend class prepare_primitive_fusing;          // to be removed when possible
58     friend class prepare_conv_eltw_fusing;          // to be removed when possible
59     friend class prepare_conv_eltw_read_write_opt;  // to be removed when possible
60     friend class propagate_constants;               // to be removed when possible
61     friend class post_optimize_weights;             // to be removed when possible - requires an access to selected_impl
62
63     template <class PType>
64     friend struct typed_program_node;
65
66     program_node(std::shared_ptr<primitive> prim, program_impl& prog);
67
68     program_node(program_node const&) = delete;
69
70     virtual ~program_node() = default;
71
72 public:
73     virtual const primitive_id& id() const { return desc->id; }
74     virtual primitive_type_id type() const { return desc->type; }
75
76     template <class PType>
77     bool is_type() const
78     {
79         static_assert(meta::is_primitive<PType>::value, "Type argument for program_node::is_type should be a non-const, non-volatile type derived from primitive");
80         return type() == PType::type_id();
81     }
82
83     program_impl& get_program() { return myprog; }
84     program_impl const& get_program() const { return myprog; }
85
86     std::shared_ptr<primitive_impl> get_selected_impl() const { return selected_impl; }
87
88     std::vector<program_node*> const& get_dependencies() const { return dependencies; }
89     program_node& get_dependency(size_t idx) const { return *dependencies.at(idx); }
90
91     //replaces idx-th dependency of 'this' with 'new_dep', calls program::remove_if_dangling(old_dep)
92     void replace_dependency(size_t idx, program_node& new_dep);
93     //searches for 'old_dep' in dependencies list of 'this' and replaces it with 'new_dep', calls program::remove_if_dangling(old_dep)
94     void replace_dependency(program_node const& old_dep, program_node& new_dep);
95
96     std::vector<primitive_id> get_dependencies_ids() const;
97
98     void remove_dependency(size_t idx);
99     void remove_dependency(program_node& node);
100
101     std::set<primitive_id> get_memory_dependencies() const;
102     void add_memory_dependency(primitive_id);
103     void add_memory_dependency(std::vector<primitive_id>);
104
105     template<class PType>
106     bool have_user_with_type() const
107     {
108         for (auto const& usr : users)
109         {
110             if (usr->is_type<PType>()) return true;
111         }
112         return false;
113     }
114
115     bool is_detached(bool whole_branch = false);
116
117     std::list<program_node*> const& get_users() { return users; }
118     // for const method, add const to stored successors/predecessors
119     std::list<const program_node*> const& get_users() const { return reinterpret_cast<const std::list<const program_node*>&>(users); }
120
121     std::unique_ptr<json_composite> desc_to_json() const;
122     //do not modify primitive directly to keep synchronisation with graph
123     std::shared_ptr<const primitive> get_primitive() const { return desc; }
124     //primitive modification functions
125     void set_output_padding(padding const& padd)
126     {
127         //changing output padding shouldn't cause any changes to other primitives
128         //so just change it
129         output_layout.data_padding = padd;
130     }
131
132     void merge_output_padding(padding const& padd)
133     {
134         set_output_padding(padding::max(padd, output_layout.data_padding));
135     }
136
137     //only calculated output layout (for external usage), does not modify/use cached output layout nor invalidate users
138     layout calc_output_layout() const;
139
140     //uses cached output layout if valid, if not calls 'calc_output_layout' and stores its result + invalidate all users if layout has changed and @p invalidate_users_if_changed is set to true
141     layout get_output_layout(bool invalidate_users_if_changed = true);
142     //returns cached output layout if valid, otherwise throws an exception
143     layout get_output_layout() const;
144     //returns result of get_output_layout without padding
145     layout get_non_padded_output_layout(bool invalidate_users_if_changed = true);
146
147     //sets cached output layout to an arbitrary value, invalidates users if new layout differs from previous one and @p invalidate_users_if_changed is set to true
148     //returns whether output layout has changed
149     bool set_output_layout(layout new_layout, bool invalidate_users_if_changed = true);
150
151     //forces recalculation of cached output layout, invalidates users if new layout is different than previous one and @p invalidate_users_if_changed is set to true
152     //returns whether output layout has changed
153     bool recalc_output_layout(bool invalidate_users_if_changed = true);
154
155     bool is_padded() { return static_cast<bool>(get_output_layout().data_padding); }
156     bool is_padded() const { return static_cast<bool>(get_output_layout().data_padding); }
157
158     bool has_padded_dependency();
159     bool has_padded_dependency() const;
160
161     bool is_input() const { return dependencies.empty(); }
162     bool is_endpoint() const { return users.empty(); }
163     void set_output(bool out) { output = out; }
164     bool is_output() const { return output; }
165
166     bool is_valid_output_layout() const { return valid_output_layout; }
167
168     uint8_t mark(uint8_t val = 1) { uint8_t ret = user_mark; user_mark = val; return ret; }
169     void unmark() { user_mark = 0; }
170     bool is_marked() const { return user_mark != 0; }
171     bool is_marked(uint8_t val) const { return user_mark == val; }
172     uint8_t get_user_mark() const { return user_mark; }
173
174     void set_fused_activation(cldnn_activation_func activation_func, cldnn_activation_additional_params additional_params)
175     {
176         fused_activation.activation_func = activation_func;
177         fused_activation.additional_params = additional_params;
178     }
179
180     cldnn_activation_func get_fused_activation_func() const
181     {
182         return fused_activation.activation_func;
183     }
184
185     cldnn_activation_additional_params get_fused_activation_params() const
186     {
187         return fused_activation.additional_params;
188     }
189
190     // check/set if the node can be optimized out (removed from the network)
191     bool can_be_optimized() const { return optimized; }
192     void can_be_optimized(bool opt) { optimized = opt; }
193
194     // check/set if the node's buffer can be shared during the memory pool optimization
195     bool can_share_buffer() const { return share_buffer; }
196     void can_share_buffer(bool share) { share_buffer = share; }
197
198     // check/set if the node support padding in x,y,b and f
199     bool support_padding() const { return _support_padding; }
200     void support_padding(bool support) { _support_padding = support; }
201
202     primitive_id get_org_primitive_id() const { return org_id; }
203
204     bool is_constant() const { return constant; }
205     
206     // returns true if this node is within main data flow of the network (i.e. it does not describe helper data like convolution's weights etc.)
207     bool is_in_data_flow() const { return data_flow; }
208
209     //conversion from generic to specific
210     template <class To, class..., class = typename std::enable_if<!std::is_same<To, primitive>::value>::type>
211     typed_program_node<To>& as()
212     {
213         if (type() != To::type_id())
214             throw std::invalid_argument("program_node: mismatching primitive's type");
215
216         return reinterpret_cast<typed_program_node<To>&>(*this);
217     }
218
219     template <class To, class..., class = typename std::enable_if<!std::is_same<To, primitive>::value>::type>
220     typed_program_node<To> const& as() const
221     {
222         if (type() != To::type_id())
223             throw std::invalid_argument("program_node: mismatching primitive's type");
224
225         return reinterpret_cast<typed_program_node<To> const&>(*this);
226     }
227
228     template <class To>
229     operator typed_program_node<To>& ()
230     {
231         return as<To>();
232     }
233
234     template <class To>
235     operator typed_program_node<To> const& () const
236     {
237         return as<To>();
238     }
239
240     void set_reused_memory_color(uint32_t color) const
241     {
242         has_reused_memory = true;
243         reused_memory_color = color;
244     }
245
246     bool is_reusing_memory() { return has_reused_memory; };
247     uint32_t get_reused_memory_color() { return reused_memory_color; ; }
248
249 protected:
250     std::shared_ptr<primitive> desc;
251     program_impl& myprog;
252
253     std::shared_ptr<primitive_impl> selected_impl;
254
255     bool valid_output_layout = false;
256     layout output_layout = layout(data_types::f32, format::bfyx, tensor());
257
258     std::vector<program_node*> dependencies;
259     std::list<program_node*> users;
260
261     // list of primitives that can reuse same memory buffers due to execution order conflicts
262     std::set<primitive_id> memory_dependencies;
263
264     bool constant = false;
265     bool data_flow = false;
266
267     bool output = false;
268     uint8_t user_mark = 0;
269     bool optimized = false;
270     bool share_buffer = true;
271     bool _support_padding = false;
272
273     mutable bool has_reused_memory = false;
274     mutable uint32_t reused_memory_color = 0;
275
276     const primitive_id org_id;
277
278     struct fused_activation_params
279     {
280         cldnn_activation_func activation_func = activation_none;
281         cldnn_activation_additional_params additional_params = { 0.0f, 0.0f };
282     };
283
284     fused_activation_params fused_activation;
285
286     void invalidate_users() const;
287 };
288
289 namespace details
290 {
291     template <class PType>
292     struct api_typed_program_node_base : public program_node
293     {
294         static_assert(meta::is_api_primitive<PType>::value, "PType should name a non-const, non-volatile type derived from cldnn::primitive but not from cldnn::internal_primitive");
295         friend class cldnn::graph_initializations;
296         friend struct cldnn::program_impl;
297         friend class cldnn::reorder_inputs;
298     public:
299         using program_node::program_node;
300
301         std::shared_ptr<const PType> get_primitive() const { return std::static_pointer_cast<const PType>(program_node::get_primitive()); }
302
303     protected:
304         std::shared_ptr<PType> typed_desc() const { return std::static_pointer_cast<PType>(desc); }
305     };
306
307     struct internal_program_node_base : public program_node
308     {
309         friend struct cldnn::program_impl;
310
311         internal_program_node_base(program_impl& prog);
312
313         const primitive_id& id() const override { return internal_id; }
314
315         void set_implementation(std::unique_ptr<primitive_impl>&& impl);
316
317     private:
318         primitive_id internal_id;
319
320         static primitive_id get_next_internal_id();
321     };
322
323     template <class PType>
324     struct internal_typed_program_node_base : public internal_program_node_base
325     {
326         static_assert(meta::is_internal_primitive<PType>::value, "PType should name a non-const, non-volatile type derived from cldnn::internal_primitive");
327
328     public:
329         using internal_program_node_base::internal_program_node_base;
330
331         primitive_type_id type() const override { return PType::type_id(); }
332
333         template <class... Guard>
334         [[noreturn]]
335         void get_primitive(Guard&&...)
336         {
337             static_assert(meta::always_false<meta::pack<Guard...>>::value, "Trying to get primitive from internal node");
338         }
339
340
341     protected:
342         template <class... Guard>
343         [[noreturn]]
344         void typed_desc(Guard&&...)
345         {
346             static_assert(meta::always_false<meta::pack<Guard...>>::value, "Trying to get primitive from internal node");
347         }
348     };
349 }
350
351 /*
352 Template class used to indicate that usage context requires 'program_node' to wrap primitive
353 of type 'PType'. Successful conversion from 'program_node' to 'typed_program_node<PType>' means
354 that this restriction in fact holds and functions/method/etc. may saftly use uderlaying primitive.
355
356 This class shadows 'get_primitive' method from base class which now returns pointer to more specific
357 type.
358 */
359 template <class PType>
360 using typed_program_node_base = typename std::conditional<meta::is_api_primitive<PType>::value, details::api_typed_program_node_base<PType>, details::internal_typed_program_node_base<PType>>::type;
361
362 /*
363     Actual template class used in context which requires 'program_node' to wrap
364     primitive of type 'PType'. This class is introduced to provide possibility of explicit specialization.
365     In most cases such specializations would add accessors to make access to PType-specific fields easier.
366
367     It's not required to specialize this class for new primitives types.
368 */
369 template <class PType>
370 struct typed_program_node : public typed_program_node_base<PType>
371 {
372     using typed_program_node_base<PType>::typed_program_node_base;
373
374     program_node& input() const { return program_node::get_dependency(0); }
375 };
376
377 }