Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / constants_propagator.cpp
1 /*
2 // Copyright (c) 2017 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "constants_propagator.h"
18 #include "engine_impl.h"
19 #include "program_impl.h"
20 #include "network_impl.h"
21 #include "memory_impl.h"
22
23 #include "api/CPP/input_layout.hpp"
24
25 using namespace cldnn;
26
27 constants_propagator::constants_propagator(program_impl::ptr program) : prog(program)
28 {
29 }
30
31 void constants_propagator::visit_node(program_node& node)
32 {
33     if (node.is_constant())
34         handle_constant(node);
35 }
36
37 std::list<std::pair<primitive_id, memory_impl::ptr>> constants_propagator::calculate()
38 {
39     if (!has_non_trivial_constants)
40         return{};
41
42     build_options bo;
43     bo.set_option(build_option::optimize_data(false));
44     bo.set_option(build_option::outputs(const_outputs));
45     network_impl::ptr net = prog->get_engine().build_network(tpl, bo, true);
46     for (auto& cin : const_inputs)
47         net->set_input_data(cin->id(), cin->get_attached_memory());
48
49     net->execute({});
50     net->reset_execution(true); //wait for computations to complete
51     auto outputs = net->get_outputs();
52
53     std::list<std::pair<primitive_id, memory_impl::ptr>> ret;
54     for (auto& out : outputs)
55         ret.push_back({ out->id(), &out->output_memory() });
56
57     return ret;
58 }
59
60 void constants_propagator::handle_constant(program_node& node)
61 {
62     if (!node.is_type<data>())
63     {
64         add_constant(node);
65         if (node.has_non_const_user())
66             const_outputs.push_back(node.id());
67     }
68 }
69
70 void constants_propagator::add_constant(program_node& node)
71 {
72     if (node.is_type<data>())
73         return;
74
75     tpl.add(node.desc);
76     has_non_trivial_constants = true;
77
78     //if a node is either an endpoint or an output, always add it as an output
79     if (node.is_endpoint() || node.is_output())
80         const_outputs.push_back(node.id());
81
82     //if a non-tirivial constant has a trivial input, add this input as an input for our network
83     add_deps_to_tpl(node.get_dependencies());
84 }
85
86 void constants_propagator::add_deps_to_tpl(const std::vector<program_node*>& deps)
87 {
88      /*   
89         Nodes can share dependencies, if we already have dep in tpl, don't add it again.
90         example:          
91             C   <--- shared dep
92            / \
93           /   \
94          A     B
95      */
96     for (auto& dep : deps)
97     {
98         if (dep->is_type<data>())
99         {
100             if (is_already_in_tpl(dep->id())) continue;
101             tpl.add(std::make_shared<input_layout>(dep->id(), dep->as<data>().get_primitive()->mem.get_layout()));
102             const_inputs.push_back(&dep->as<data>());
103         }
104     }
105 }
106
107 bool constants_propagator::is_already_in_tpl(const primitive_id& id)
108 {
109     for (auto const& id_in_tpl : tpl.get_primitives_id())
110     {
111         if (id == id_in_tpl) return true;
112     }
113     return false;
114 }